From 1006e8092f13af98bda8f571cf5a1328f8ccd4cd Mon Sep 17 00:00:00 2001 From: SomberNight Date: Wed, 17 Dec 2025 15:16:05 +0000 Subject: [PATCH] lnworker: split LNWallet and LNWorker: LNWallet "has an" LNWorker - LNWallet no longer "is-an" LNWorker, instead LNWallet "has-an" LNWorker - the motivation is to make the unit tests nicer, and allow writing unit tests for more things - I hope this makes it possible to e.g. test lnsweep in the unit tests - some stuff we would previously have to write a regtest for, maybe we can write a unit test for, now - in unit tests, MockLNWallet now - inherits LNWallet - the Wallet is no longer being mocked --- electrum/commands.py | 10 +- electrum/gui/qml/qechanneldetails.py | 2 +- electrum/gui/qml/qechannellistmodel.py | 2 +- electrum/gui/qml/qeinvoice.py | 2 +- electrum/gui/qml/qewallet.py | 2 +- electrum/gui/qt/channels_list.py | 2 +- electrum/gui/qt/lightning_dialog.py | 2 +- electrum/lnchannel.py | 10 +- electrum/lnpeer.py | 8 +- electrum/lnworker.py | 282 ++++++++++++++-------- electrum/onion_message.py | 24 +- electrum/scripts/ln_features.py | 2 +- electrum/wallet.py | 4 +- tests/test_commands.py | 7 +- tests/test_lnchannel.py | 6 +- tests/test_lnpeer.py | 313 ++++++++----------------- tests/test_onion_message.py | 21 +- 17 files changed, 345 insertions(+), 354 deletions(-) diff --git a/electrum/commands.py b/electrum/commands.py index 4a60e8b3e..45e3d92c8 100644 --- a/electrum/commands.py +++ b/electrum/commands.py @@ -1687,7 +1687,7 @@ class Commands(Logger): arg:int:timeout:Timeout in seconds (default=20) """ lnworker = self.network.lngossip if gossip else wallet.lnworker - peer = await lnworker.add_peer(connection_string) + peer = await lnworker.lnpeermgr.add_peer(connection_string) try: await util.wait_for2(peer.initialized, timeout=LN_P2P_NETWORK_TIMEOUT) except (CancelledError, Exception) as e: @@ -1700,7 +1700,7 @@ class Commands(Logger): """Display statistics about lightninig gossip""" lngossip = self.network.lngossip channel_db = lngossip.channel_db - forwarded = dict([(key.hex(), p._num_gossip_messages_forwarded) for key, p in wallet.lnworker.peers.items()]), + forwarded = dict([(key.hex(), p._num_gossip_messages_forwarded) for key, p in wallet.lnworker.lnpeermgr.peers.items()]), out = { 'received': { 'channel_announcements': lngossip._num_chan_ann, @@ -1731,7 +1731,7 @@ class Commands(Logger): 'initialized': p.is_initialized(), 'features': str(LnFeatures(p.features)), 'channels': [c.funding_outpoint.to_str() for c in p.channels.values()], - } for p in lnworker.peers.values()] + } for p in lnworker.lnpeermgr.peers.values()] @command('wpnl') async def open_channel(self, connection_string, amount, push_amount=0, public=False, zeroconf=False, password=None, wallet: Abstract_Wallet = None): @@ -1748,7 +1748,7 @@ class Commands(Logger): raise UserFacingException("This wallet cannot create new channels") funding_sat = satoshis(amount) push_sat = satoshis(push_amount) - peer = await wallet.lnworker.add_peer(connection_string) + peer = await wallet.lnworker.lnpeermgr.add_peer(connection_string) chan, funding_tx = await wallet.lnworker.open_channel_with_peer( peer, funding_sat, push_sat=push_sat, @@ -2197,7 +2197,7 @@ class Commands(Logger): pubkey = bfh(node_id) assert len(pubkey) == 33, 'invalid node_id' - peer = wallet.lnworker.peers[pubkey] + peer = wallet.lnworker.lnpeermgr.peers[pubkey] assert peer, 'node_id not a peer' path = [pubkey, wallet.lnworker.node_keypair.pubkey] diff --git a/electrum/gui/qml/qechanneldetails.py b/electrum/gui/qml/qechanneldetails.py index e6f59defa..9ac86b7e4 100644 --- a/electrum/gui/qml/qechanneldetails.py +++ b/electrum/gui/qml/qechanneldetails.py @@ -94,7 +94,7 @@ class QEChannelDetails(AuthMixin, QObject, QtEventListener): def name(self) -> str: if not self._channel: return '' - return self._wallet.wallet.lnworker.get_node_alias(self._channel.node_id) or '' + return self._wallet.wallet.lnworker.lnpeermgr.get_node_alias(self._channel.node_id) or '' @pyqtProperty(str, notify=channelChanged) def pubkey(self) -> str: diff --git a/electrum/gui/qml/qechannellistmodel.py b/electrum/gui/qml/qechannellistmodel.py index b3bf4586e..da6be7bb2 100644 --- a/electrum/gui/qml/qechannellistmodel.py +++ b/electrum/gui/qml/qechannellistmodel.py @@ -88,7 +88,7 @@ class QEChannelListModel(QAbstractListModel, QtEventListener): item = { 'cid': lnc.channel_id.hex(), 'node_id': lnc.node_id.hex(), - 'node_alias': lnworker.get_node_alias(lnc.node_id) or '', + 'node_alias': lnworker.lnpeermgr.get_node_alias(lnc.node_id) or '', 'short_cid': lnc.short_id_for_GUI(), 'state': lnc.get_state_for_GUI(), 'state_code': int(lnc.get_state()), diff --git a/electrum/gui/qml/qeinvoice.py b/electrum/gui/qml/qeinvoice.py index c3a9a8f33..640363fdf 100644 --- a/electrum/gui/qml/qeinvoice.py +++ b/electrum/gui/qml/qeinvoice.py @@ -258,7 +258,7 @@ class QEInvoice(QObject, QtEventListener): def name_for_node_id(self, node_id): lnworker = self._wallet.wallet.lnworker - return (lnworker.get_node_alias(node_id) if lnworker else None) or node_id.hex() + return (lnworker.lnpeermgr.get_node_alias(node_id) if lnworker else None) or node_id.hex() def set_effective_invoice(self, invoice: Invoice): self._effectiveInvoice = invoice diff --git a/electrum/gui/qml/qewallet.py b/electrum/gui/qml/qewallet.py index b9c593db6..f08005baf 100644 --- a/electrum/gui/qml/qewallet.py +++ b/electrum/gui/qml/qewallet.py @@ -525,7 +525,7 @@ class QEWallet(AuthMixin, QObject, QtEventListener): @pyqtProperty(int, notify=peersUpdated) def lightningNumPeers(self): if self.isLightning: - return self.wallet.lnworker.num_peers() + return self.wallet.lnworker.lnpeermgr.num_peers() return 0 @pyqtSlot() diff --git a/electrum/gui/qt/channels_list.py b/electrum/gui/qt/channels_list.py index c62d49c5a..f8c4eb484 100644 --- a/electrum/gui/qt/channels_list.py +++ b/electrum/gui/qt/channels_list.py @@ -98,7 +98,7 @@ class ChannelsList(MyTreeView): labels[subject] = label status = chan.get_state_for_GUI() closed = chan.is_closed() - node_alias = self.lnworker.get_node_alias(chan.node_id) or chan.node_id.hex() + node_alias = self.lnworker.lnpeermgr.get_node_alias(chan.node_id) or chan.node_id.hex() capacity_str = self.main_window.format_amount(chan.get_capacity(), whitespaces=True) return { self.Columns.SHORT_CHANID: chan.short_id_for_GUI(), diff --git a/electrum/gui/qt/lightning_dialog.py b/electrum/gui/qt/lightning_dialog.py index 4c6fb37f3..ec9ea9f06 100644 --- a/electrum/gui/qt/lightning_dialog.py +++ b/electrum/gui/qt/lightning_dialog.py @@ -62,7 +62,7 @@ class LightningDialog(QDialog, QtEventListener): self.register_callbacks() self.network.channel_db.update_counts() # trigger callback if self.network.lngossip: - self.on_event_gossip_peers(self.network.lngossip.num_peers()) + self.on_event_gossip_peers(self.network.lngossip.lnpeermgr.num_peers()) self.on_event_unknown_channels(len(self.network.lngossip.unknown_ids)) else: self.num_peers.setText(_('Lightning gossip not active.')) diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index 683f44a62..f4d19dd9c 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -417,9 +417,9 @@ class AbstractChannel(Logger, ABC): if not self.is_funding_tx_mined(funding_height): # funding tx is invalid (invalid amount or address) we need to get rid of the channel again self.should_request_force_close = True - if self.lnworker and self.node_id in self.lnworker.peers: + if self.lnworker and (peer := self.lnworker.lnpeermgr.get_peer_by_pubkey(self.node_id)): # reconnect to trigger force close request - self.lnworker.peers[self.node_id].close_and_cleanup() + peer.close_and_cleanup() else: # remove zeroconf flag as we are now confirmed, this is to prevent an electrum server causing # us to remove a channel later in update_unfunded_state by omitting its funding tx @@ -779,7 +779,7 @@ class Channel(AbstractChannel): self, state: 'StoredDict', *, name=None, - lnworker=None, # None only in unittests + lnworker: 'LNWallet' = None, # None only in unittests initial_feerate=None, jit_opening_fee: Optional[int] = None, ): @@ -1022,8 +1022,8 @@ class Channel(AbstractChannel): elif self.is_static_remotekey_enabled(): our_payment_pubkey = self.config[LOCAL].payment_basepoint.pubkey addr = make_commitment_output_to_remote_address(our_payment_pubkey, has_anchors=self.has_anchors()) - if self.lnworker: - assert self.lnworker.wallet.is_mine(addr) + #if self.lnworker: + # assert self.lnworker.wallet.is_mine(addr) # FIXME xxxxx chan should be deterministic. NEEDS to be fixed before merge return addr def has_anchors(self) -> bool: diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 085fc7576..73fc701d7 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -81,7 +81,7 @@ class Peer(Logger, EventListener): def __init__( self, - lnworker: Union['LNGossip', 'LNWallet'], + lnworker: Union['LNWallet', 'LNGossip'], pubkey: bytes, transport: LNTransportBase, *, is_channel_backup= False): @@ -402,7 +402,7 @@ class Peer(Logger, EventListener): if constants.net.rev_genesis_bytes() not in their_chains: raise GracefulDisconnect(f"no common chain found with remote. (they sent: {their_chains})") # all checks passed - self.lnworker.on_peer_successfully_established(self) + self.lnworker.lnpeermgr.on_peer_successfully_established(self) self._received_init = True self.maybe_set_initialized() @@ -888,7 +888,7 @@ class Peer(Logger, EventListener): self.transport.close() except Exception: pass - self.lnworker.peer_closed(self) + self.lnworker.lnpeermgr.peer_closed(self) self.got_disconnected.set() def is_shutdown_anysegwit(self): @@ -3064,7 +3064,7 @@ class Peer(Logger, EventListener): or not self.lnworker.is_payment_bundle_complete(payment_key): # maybe this set is COMPLETE but the bundle is not yet completed, so the bundle can be considered WAITING if int(time.time()) - first_htlc_timestamp > self.lnworker.MPP_EXPIRY \ - or self.lnworker.stopping_soon: + or self.lnworker.lnpeermgr.stopping_soon: _log_fail_reason(f"MPP TIMEOUT (> {self.lnworker.MPP_EXPIRY} sec)") return OnionFailureCode.MPP_TIMEOUT, None, None diff --git a/electrum/lnworker.py b/electrum/lnworker.py index dddf5350c..6add346ad 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -39,7 +39,7 @@ from .util import ( profiler, OldTaskGroup, ESocksProxy, NetworkRetryManager, JsonRPCClient, NotEnoughFunds, EventListener, event_listener, bfh, InvoiceError, resolve_dns_srv, is_ip_address, log_exceptions, ignore_exceptions, make_aiohttp_session, random_shuffled_copy, is_private_netaddress, - UnrelatedTransactionException, LightningHistoryItem + UnrelatedTransactionException, LightningHistoryItem, get_asyncio_loop, ) from .fee_policy import ( FeePolicy, FEERATE_FALLBACK_STATIC_FEE, FEE_LN_ETA_TARGET, FEE_LN_LOW_ETA_TARGET, @@ -215,9 +215,15 @@ LNGOSSIP_FEATURES = ( ) -class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): +class LNPeerManager(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): - def __init__(self, node_keypair, features: LnFeatures, *, config: 'SimpleConfig'): + def __init__( + self, node_keypair, + *, + lnwallet_or_lngossip: 'LNWallet | LNGossip', + features: LnFeatures, + config: 'SimpleConfig', + ): Logger.__init__(self) NetworkRetryManager.__init__( self, @@ -228,6 +234,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): ) self.lock = threading.RLock() self.node_keypair = node_keypair + self._lnwallet_or_lngossip = lnwallet_or_lngossip self._peers = {} # type: Dict[bytes, Peer] # pubkey -> Peer # needs self.lock self._channelless_incoming_peers = set() # type: Set[bytes] # node_ids # needs self.lock self.taskgroup = OldTaskGroup() @@ -252,7 +259,10 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): return self._peers.copy() def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]: - return {} + return self._lnwallet_or_lngossip.channels_for_peer(node_id) + + def get_peer_by_pubkey(self, pubkey: bytes) -> Optional[Peer]: + return self._peers.get(pubkey) def get_node_alias(self, node_id: bytes) -> Optional[str]: """Returns the alias of the node, or None if unknown.""" @@ -269,7 +279,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): return node_alias async def maybe_listen(self): - # FIXME: only one LNWorker can listen at a time (single port) + # FIXME: only one LNPeerManager can listen at a time (single port) listen_addr = self.config.LIGHTNING_LISTEN if listen_addr: self.logger.info(f'lightning_listen enabled. will try to bind: {listen_addr!r}') @@ -368,13 +378,17 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): return None self._channelless_incoming_peers.add(node_id) # checks done: we are adding this peer. - peer = Peer(self, node_id, transport) + peer = Peer(self._lnwallet_or_lngossip, node_id, transport) assert node_id not in self._peers self._peers[node_id] = peer await self.taskgroup.spawn(peer.main_loop()) return peer def peer_closed(self, peer: Peer) -> None: + if isinstance(self._lnwallet_or_lngossip, LNWallet): + for chan in self.channels_for_peer(peer.pubkey).values(): + chan.peer_state = PeerState.DISCONNECTED + util.trigger_callback('channel', self._lnwallet_or_lngossip.wallet, chan) with self.lock: peer2 = self._peers.get(peer.pubkey) if peer2 is peer: @@ -392,14 +406,25 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): return True return False - def start_network(self, network: 'Network'): + def start_network( + self, network: 'Network', *, + listen: bool = False, + maintain_random_peers: bool = False, + ) -> None: assert network assert self.network is None, "already started" self.network = network self._add_peers_from_config() - asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop) + asyncio.run_coroutine_threadsafe(self.main_loop(), get_asyncio_loop()) + if listen: + tg_coro = self.taskgroup.spawn(self.maybe_listen()) + asyncio.run_coroutine_threadsafe(tg_coro, get_asyncio_loop()) + if maintain_random_peers: + tg_coro = self.taskgroup.spawn(self._maintain_connectivity()) + asyncio.run_coroutine_threadsafe(tg_coro, get_asyncio_loop()) async def stop(self): + self.stopping_soon = True if self.listen_server: self.listen_server.close() self.unregister_callbacks() @@ -410,7 +435,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): for host, port, pubkey in peer_list: asyncio.run_coroutine_threadsafe( self._add_peer(host, int(port), bfh(pubkey)), - self.network.asyncio_loop) + get_asyncio_loop()) def is_good_peer(self, peer: LNPeerAddr) -> bool: # the purpose of this method is to filter peers that advertise the desired feature bits @@ -573,8 +598,44 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): peer = await self._add_peer(host, port, node_id) return peer + async def reestablish_peer_for_given_channel(self, chan: Channel) -> None: + await self.taskgroup.spawn(self._reestablish_peer_for_given_channel(chan)) -class LNGossip(LNWorker): + @ignore_exceptions + @log_exceptions + async def _reestablish_peer_for_given_channel(self, chan: Channel) -> None: + now = time.time() + peer_addresses = [] + if self.uses_trampoline(): + addr = trampolines_by_id().get(chan.node_id) + if addr: + peer_addresses.append(addr) + else: + # will try last good address first, from gossip + last_good_addr = self.channel_db.get_last_good_address(chan.node_id) + if last_good_addr: + peer_addresses.append(last_good_addr) + # will try addresses for node_id from gossip + addrs_from_gossip = self.channel_db.get_node_addresses(chan.node_id) or [] + for host, port, ts in addrs_from_gossip: + peer_addresses.append(LNPeerAddr(host, port, chan.node_id)) + # will try addresses stored in channel storage + peer_addresses += list(chan.get_peer_addresses()) + # Done gathering addresses. + # Now select first one that has not failed recently. + for peer in peer_addresses: + if self._can_retry_addr(peer, urgent=True, now=now): + await self._add_peer(peer.host, peer.port, peer.pubkey) + return + + async def reestablish_peer_for_zero_conf_trusted_node(self) -> None: + if self.config.ZEROCONF_TRUSTED_NODE: + peer = LNPeerAddr.from_str(self.config.ZEROCONF_TRUSTED_NODE) + if self._can_retry_addr(peer, urgent=True): + await self._add_peer(peer.host, peer.port, peer.pubkey) + + +class LNGossip(Logger): """The LNGossip class is a separate, unannounced Lightning node with random id that is just querying gossip from other nodes. The LNGossip node does not satisfy gossip queries, this is done by the LNWallet class(es). LNWallets are the advertised nodes used for actual payments and only satisfy @@ -584,11 +645,14 @@ class LNGossip(LNWorker): max_age = 14*24*3600 def __init__(self, config: 'SimpleConfig'): + self.config = config seed = os.urandom(32) node = BIP32Node.from_rootseed(seed, xtype='standard') xprv = node.to_xprv() node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY) - LNWorker.__init__(self, node_keypair, LNGOSSIP_FEATURES, config=config) + Logger.__init__(self) + self.lnpeermgr = LNPeerManager(node_keypair, features=LNGOSSIP_FEATURES, config=self.config, lnwallet_or_lngossip=self) + self.taskgroup = OldTaskGroup() self.unknown_ids = set() self._forwarding_gossip = [] # type: List[GossipForwardingMessage] self._last_gossip_batch_ts = 0 # type: int @@ -600,15 +664,44 @@ class LNGossip(LNWorker): self._num_chan_upd = 0 self._num_chan_upd_good = 0 + @property + def features(self) -> 'LnFeatures': + return self.lnpeermgr.features + + @property + def network(self) -> Optional['Network']: + return self.lnpeermgr.network + + @property + def channel_db(self) -> 'ChannelDB': + return self.network.channel_db if self.network else None + + def uses_trampoline(self) -> bool: + return not bool(self.channel_db) + + async def main_loop(self): + self.logger.info("starting taskgroup.") + try: + async with self.taskgroup as group: + await group.spawn(asyncio.Event().wait) # run forever (until cancel) + except Exception as e: + self.logger.exception("taskgroup died.") + finally: + self.logger.info("taskgroup stopped.") + def start_network(self, network: 'Network'): - super().start_network(network) + asyncio.run_coroutine_threadsafe(self.main_loop(), get_asyncio_loop()) + self.lnpeermgr.start_network(network, maintain_random_peers=True) for coro in [ - self._maintain_connectivity(), self.maintain_db(), self._maintain_forwarding_gossip() ]: tg_coro = self.taskgroup.spawn(coro) - asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop) + asyncio.run_coroutine_threadsafe(tg_coro, get_asyncio_loop()) + + async def stop(self): + await self.lnpeermgr.stop() + await self.taskgroup.cancel_remaining() async def maintain_db(self): await self.channel_db.data_loaded.wait() @@ -637,7 +730,7 @@ class LNGossip(LNWorker): new = set(ids) - set(known) self.unknown_ids.update(new) util.trigger_callback('unknown_channels', len(self.unknown_ids)) - util.trigger_callback('gossip_peers', self.num_peers()) + util.trigger_callback('gossip_peers', self.lnpeermgr.num_peers()) util.trigger_callback('ln_gossip_sync_progress') def get_ids_to_query(self) -> Sequence[bytes]: @@ -652,7 +745,7 @@ class LNGossip(LNWorker): """Estimates the gossip synchronization process and returns the number of synchronized channels, the total channels in the network and a rescaled percentage of the synchronization process.""" - if self.num_peers() == 0: + if self.lnpeermgr.num_peers() == 0: return None, None, None nchans_with_0p, nchans_with_1p, nchans_with_2p = self.channel_db.get_num_channels_partitioned_by_policy_count() num_db_channels = nchans_with_0p + nchans_with_1p + nchans_with_2p @@ -730,6 +823,9 @@ class LNGossip(LNWorker): # flush the gossip queue so we don't forward old gossip after sync is complete self.channel_db.get_forwarding_gossip_batch() + def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]: + return {} + class PaySession(Logger): @@ -875,7 +971,7 @@ class PaySession(Logger): return nhtlcs_resolved == self._nhtlcs_inflight -class LNWallet(LNWorker): +class LNWallet(Logger): lnwatcher: Optional['LNWatcher'] MPP_EXPIRY = 120 @@ -884,7 +980,7 @@ class LNWallet(LNWorker): MPP_SPLIT_PART_FRACTION = 0.2 MPP_SPLIT_PART_MINAMT_MSAT = 5_000_000 - def __init__(self, wallet: 'Abstract_Wallet', xprv): + def __init__(self, wallet: 'Abstract_Wallet', xprv, *, features: LnFeatures = None): self.wallet = wallet self.config = wallet.config self.db = wallet.db @@ -894,16 +990,20 @@ class LNWallet(LNWorker): self.payment_secret_key = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.PAYMENT_SECRET_KEY).privkey self.funding_root_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.FUNDING_ROOT_KEY) Logger.__init__(self) - features = LNWALLET_FEATURES - if self.config.ENABLE_ANCHOR_CHANNELS: - features |= LnFeatures.OPTION_ANCHORS_ZERO_FEE_HTLC_OPT - if self.config.ACCEPT_ZEROCONF_CHANNELS: - features |= LnFeatures.OPTION_ZEROCONF_OPT - if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS or self.config.EXPERIMENTAL_LN_FORWARD_TRAMPOLINE_PAYMENTS: - features |= LnFeatures.OPTION_ONION_MESSAGE_OPT - if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS and self.config.LIGHTNING_USE_GOSSIP: - features |= LnFeatures.GOSSIP_QUERIES_OPT # signal we have gossip to fetch - LNWorker.__init__(self, self.node_keypair, features, config=self.config) + if features is None: + features = LNWALLET_FEATURES + if self.config.ENABLE_ANCHOR_CHANNELS: + features |= LnFeatures.OPTION_ANCHORS_ZERO_FEE_HTLC_OPT + if self.config.ACCEPT_ZEROCONF_CHANNELS: + features |= LnFeatures.OPTION_ZEROCONF_OPT + if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS or self.config.EXPERIMENTAL_LN_FORWARD_TRAMPOLINE_PAYMENTS: + features |= LnFeatures.OPTION_ONION_MESSAGE_OPT + if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS and self.config.LIGHTNING_USE_GOSSIP: + features |= LnFeatures.GOSSIP_QUERIES_OPT # signal we have gossip to fetch + Logger.__init__(self) + self.lock = threading.RLock() + self.lnpeermgr = LNPeerManager(self.node_keypair, features=features, config=self.config, lnwallet_or_lngossip=self) + self.taskgroup = OldTaskGroup() self.lnwatcher = LNWatcher(self) self.lnrater: LNRater = None # "RHASH:direction" -> amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts, invoice_features @@ -997,6 +1097,21 @@ class LNWallet(LNWorker): return any(chan.has_anchors() and not chan.is_closed() for chan in self.channels.values()) + @property + def features(self) -> 'LnFeatures': + return self.lnpeermgr.features + + @property + def network(self) -> Optional['Network']: + return self.lnpeermgr.network + + @property + def channel_db(self) -> 'ChannelDB': + return self.network.channel_db if self.network else None + + def uses_trampoline(self) -> bool: + return not bool(self.channel_db) + @property def channels(self) -> Mapping[bytes, Channel]: """Returns a read-only copy of channels.""" @@ -1060,29 +1175,39 @@ class LNWallet(LNWorker): await watchtower.add_sweep_tx(outpoint, ctn, tx.inputs()[0].prevout.to_str(), tx.serialize()) self.watchtower_ctns[outpoint] = ctn + async def main_loop(self): + self.logger.info("starting taskgroup.") + try: + async with self.taskgroup as group: + await group.spawn(asyncio.Event().wait) # run forever (until cancel) + except Exception as e: + self.logger.exception("taskgroup died.") + finally: + self.logger.info("taskgroup stopped.") + def start_network(self, network: 'Network'): - super().start_network(network) + asyncio.run_coroutine_threadsafe(self.main_loop(), get_asyncio_loop()) + self.lnpeermgr.start_network(network, listen=True) self.lnwatcher.start_network(network) self.swap_manager.start_network(network) self.lnrater = LNRater(self, network) self.onion_message_manager.start_network(network=network) for coro in [ - self.maybe_listen(), self.lnwatcher.trigger_callbacks(), # shortcut (don't block) if funding tx locked and verified self.reestablish_peers_and_channels(), self.sync_with_remote_watchtower(), ]: tg_coro = self.taskgroup.spawn(coro) - asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop) + asyncio.run_coroutine_threadsafe(tg_coro, get_asyncio_loop()) async def stop(self): - self.stopping_soon = True - if self.listen_server: # stop accepting new peers - self.listen_server.close() + self.lnpeermgr.stopping_soon = True + if self.lnpeermgr.listen_server: # stop accepting new peers + self.lnpeermgr.listen_server.close() async with ignore_after(self.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS): await self.wait_for_received_pending_htlcs_to_get_removed() - await LNWorker.stop(self) + await self.lnpeermgr.stop() if self.lnwatcher: self.lnwatcher.stop() self.lnwatcher = None @@ -1090,30 +1215,25 @@ class LNWallet(LNWorker): await self.swap_manager.stop() if self.onion_message_manager: await self.onion_message_manager.stop() + await self.taskgroup.cancel_remaining() async def wait_for_received_pending_htlcs_to_get_removed(self): - assert self.stopping_soon is True + assert self.lnpeermgr.stopping_soon is True # We try to fail pending MPP HTLCs, and wait a bit for them to get removed. # Note: even without MPP, if we just failed/fulfilled an HTLC, it is good # to wait a bit for it to become irrevocably removed. # Note: we don't wait for *all htlcs* to get removed, only for those # that we can already fail/fulfill. e.g. forwarded htlcs cannot be removed async with OldTaskGroup() as group: - for peer in self.peers.values(): + for peer in self.lnpeermgr.peers.values(): await group.spawn(peer.wait_one_htlc_switch_iteration()) while True: - if all(not peer.received_htlcs_pending_removal for peer in self.peers.values()): + if all(not peer.received_htlcs_pending_removal for peer in self.lnpeermgr.peers.values()): break async with OldTaskGroup(wait=any) as group: - for peer in self.peers.values(): + for peer in self.lnpeermgr.peers.values(): await group.spawn(peer.received_htlc_removed_event.wait()) - def peer_closed(self, peer): - for chan in self.channels_for_peer(peer.pubkey).values(): - chan.peer_state = PeerState.DISCONNECTED - util.trigger_callback('channel', self.wallet, chan) - super().peer_closed(peer) - def get_payments(self, *, status=None) -> Mapping[bytes, List[HTLCWithStatus]]: out = defaultdict(list) for chan in self.channels.values(): @@ -1263,7 +1383,7 @@ class LNWallet(LNWorker): node_ids = [chan.node_id for chan in self.channels.values() if not chan.is_closed()] return node_ids - def channels_for_peer(self, node_id): + def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]: assert type(node_id) is bytes return {chan_id: chan for (chan_id, chan) in self.channels.items() if chan.node_id == node_id} @@ -1307,12 +1427,12 @@ class LNWallet(LNWorker): await self.schedule_force_closing(chan.channel_id) elif chan.get_state() == ChannelState.FUNDED: - peer = self._peers.get(chan.node_id) + peer = self.lnpeermgr.get_peer_by_pubkey(chan.node_id) if peer and peer.is_initialized() and chan.peer_state == PeerState.GOOD: peer.send_channel_ready(chan) elif chan.get_state() == ChannelState.OPEN: - peer = self._peers.get(chan.node_id) + peer = self.lnpeermgr.get_peer_by_pubkey(chan.node_id) if peer and peer.is_initialized() and chan.peer_state == PeerState.GOOD: peer.maybe_update_fee(chan) peer.maybe_send_announcement_signatures(chan) @@ -1326,9 +1446,10 @@ class LNWallet(LNWorker): await self.network.try_broadcasting(force_close_tx, 'force-close') def get_peer_by_static_jit_scid_alias(self, scid_alias: bytes) -> Optional[Peer]: - for nodeid, peer in self.peers.items(): + for nodeid, peer in self.lnpeermgr.peers.items(): if scid_alias == self._scid_alias_of_node(nodeid): return peer + return None def _scid_alias_of_node(self, nodeid: bytes) -> bytes: # scid alias for just-in-time channels @@ -1557,7 +1678,7 @@ class LNWallet(LNWorker): password: str = None, ) -> Tuple[Channel, PartialTransaction]: - fut = asyncio.run_coroutine_threadsafe(self.add_peer(connect_str), self.network.asyncio_loop) + fut = asyncio.run_coroutine_threadsafe(self.lnpeermgr.add_peer(connect_str), get_asyncio_loop()) try: peer = fut.result() except concurrent.futures.TimeoutError: @@ -1569,7 +1690,7 @@ class LNWallet(LNWorker): push_sat=push_amt_sat, public=public, password=password) - fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) + fut = asyncio.run_coroutine_threadsafe(coro, get_asyncio_loop()) try: chan, funding_tx = fut.result() except concurrent.futures.TimeoutError: @@ -1860,7 +1981,7 @@ class LNWallet(LNWorker): short_channel_id = shi.route[0].short_channel_id chan = self.get_channel_by_short_id(short_channel_id) assert chan, ShortChannelID(short_channel_id) - peer = self._peers.get(shi.route[0].node_id) + peer = self.lnpeermgr.get_peer_by_pubkey(shi.route[0].node_id) if not peer: raise PaymentFailure('Dropped peer') await peer.initialized @@ -2040,7 +2161,7 @@ class LNWallet(LNWorker): # until trampoline is advertised in lnfeatures, check against hardcoded list if is_hardcoded_trampoline(node_id): return True - peer = self._peers.get(node_id) + peer = self.lnpeermgr.get_peer_by_pubkey(node_id) if not peer: return False return (peer.their_features.supports(LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ECLAIR) @@ -2794,7 +2915,7 @@ class LNWallet(LNWorker): return upstream_chan_scid, _ = deserialize_htlc_key(upstream_key) upstream_chan = self.get_channel_by_short_id(upstream_chan_scid) - upstream_peer = self.peers.get(upstream_chan.node_id) if upstream_chan else None + upstream_peer = self.lnpeermgr.get_peer_by_pubkey(upstream_chan.node_id) if upstream_chan else None if upstream_peer: upstream_peer.downstream_htlc_resolved_event.set() upstream_peer.downstream_htlc_resolved_event.clear() @@ -3110,7 +3231,7 @@ class LNWallet(LNWorker): # invalid connection string return False # only return True if we are connected to the zeroconf provider - return node_id in self.peers + return self.lnpeermgr.get_peer_by_pubkey(node_id) is not None def _suggest_channels_for_rebalance(self, direction, amount_sat) -> Sequence[Tuple[Channel, int]]: """ @@ -3239,7 +3360,9 @@ class LNWallet(LNWorker): async def close_channel(self, chan_id): chan = self._channels[chan_id] - peer = self._peers[chan.node_id] + peer = self.lnpeermgr.get_peer_by_pubkey(chan.node_id) + if peer is None: + raise KeyError return await peer.close_channel(chan_id) def _force_close_channel(self, chan_id: bytes) -> Transaction: @@ -3288,52 +3411,23 @@ class LNWallet(LNWorker): util.trigger_callback('channels_updated', self.wallet) util.trigger_callback('wallet_updated', self.wallet) - @ignore_exceptions - @log_exceptions - async def reestablish_peer_for_given_channel(self, chan: Channel) -> None: - now = time.time() - peer_addresses = [] - if self.uses_trampoline(): - addr = trampolines_by_id().get(chan.node_id) - if addr: - peer_addresses.append(addr) - else: - # will try last good address first, from gossip - last_good_addr = self.channel_db.get_last_good_address(chan.node_id) - if last_good_addr: - peer_addresses.append(last_good_addr) - # will try addresses for node_id from gossip - addrs_from_gossip = self.channel_db.get_node_addresses(chan.node_id) or [] - for host, port, ts in addrs_from_gossip: - peer_addresses.append(LNPeerAddr(host, port, chan.node_id)) - # will try addresses stored in channel storage - peer_addresses += list(chan.get_peer_addresses()) - # Done gathering addresses. - # Now select first one that has not failed recently. - for peer in peer_addresses: - if self._can_retry_addr(peer, urgent=True, now=now): - await self._add_peer(peer.host, peer.port, peer.pubkey) - return - async def reestablish_peers_and_channels(self): while True: await asyncio.sleep(1) - if self.stopping_soon: + if self.lnpeermgr.stopping_soon: return - if self.config.ZEROCONF_TRUSTED_NODE: - peer = LNPeerAddr.from_str(self.config.ZEROCONF_TRUSTED_NODE) - if self._can_retry_addr(peer, urgent=True): - await self._add_peer(peer.host, peer.port, peer.pubkey) + await self.lnpeermgr.reestablish_peer_for_zero_conf_trusted_node() for chan in self.channels.values(): # reestablish # note: we delegate filtering out uninteresting chans to this: if not chan.should_try_to_reestablish_peer(): continue - peer = self._peers.get(chan.node_id, None) + peer = self.lnpeermgr.get_peer_by_pubkey(chan.node_id) if peer: + # FIXME maybe this should be the responsibility of the peer itself, done in peer.main_loop: await peer.taskgroup.spawn(peer.reestablish_channel(chan)) else: - await self.taskgroup.spawn(self.reestablish_peer_for_given_channel(chan)) + await self.lnpeermgr.reestablish_peer_for_given_channel(chan) def current_target_feerate_per_kw(self, *, has_anchors: bool) -> Optional[int]: target: int = FEE_LN_MINIMUM_ETA_TARGET if has_anchors else FEE_LN_ETA_TARGET @@ -3396,12 +3490,12 @@ class LNWallet(LNWorker): async def request_force_close(self, channel_id: bytes, *, connect_str=None) -> None: if chan := self.get_channel_by_id(channel_id): - peer = self._peers.get(chan.node_id) + peer = self.lnpeermgr.get_peer_by_pubkey(chan.node_id) chan.should_request_force_close = True if peer: peer.close_and_cleanup() # to force a reconnect elif connect_str: - peer = await self.add_peer(connect_str) + peer = await self.lnpeermgr.add_peer(connect_str) await peer.request_force_close(channel_id) elif channel_id in self.channel_backups: await self._request_force_close_from_backup(channel_id) @@ -3688,7 +3782,7 @@ class LNWallet(LNWorker): f"maybe_forward_htlc. will forward HTLC: inc_chan={incoming_chan.short_channel_id}. inc_htlc={str(htlc)}. " f"next_chan={next_chan.get_id_for_log()}.") - next_peer = self.peers.get(next_chan.node_id) + next_peer = self.lnpeermgr.get_peer_by_pubkey(next_chan.node_id) if next_peer is None: log_fail_reason(f"next_peer offline ({next_chan.node_id.hex()})") raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message) @@ -3774,7 +3868,7 @@ class LNWallet(LNWorker): raise OnionRoutingFailure(code=OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON, data=b'') # do we have a connection to the node? - next_peer = self.peers.get(outgoing_node_id) + next_peer = self.lnpeermgr.get_peer_by_pubkey(outgoing_node_id) if next_peer and next_peer.accepts_zeroconf(): self.logger.info(f'JIT: found next_peer') for next_chan in next_peer.channels.values(): diff --git a/electrum/onion_message.py b/electrum/onion_message.py index 121235344..a6225f3eb 100644 --- a/electrum/onion_message.py +++ b/electrum/onion_message.py @@ -162,7 +162,7 @@ def create_onion_message_route_to(lnwallet: 'LNWallet', node_id: bytes) -> Seque ): return path # alt: dest is existing peer? - if lnwallet.peers.get(node_id): + if lnwallet.lnpeermgr.get_peer_by_pubkey(node_id): return [PathEdge(short_channel_id=None, start_node=None, end_node=node_id)] # if we have an address, pass it. @@ -219,7 +219,7 @@ def send_onion_message_to( encrypted_recipient_data=our_payload['encrypted_recipient_data'] ) - peer = lnwallet.peers.get(recipient_data['next_node_id']['node_id']) + peer = lnwallet.lnpeermgr.get_peer_by_pubkey(recipient_data['next_node_id']['node_id']) assert peer, 'next_node_id not a peer' # blinding override? @@ -241,7 +241,7 @@ def send_onion_message_to( if not isinstance(remaining_blinded_path, list): # doesn't return list when num items == 1 remaining_blinded_path = [remaining_blinded_path] - peer = lnwallet.peers.get(introduction_point) + peer = lnwallet.lnpeermgr.get_peer_by_pubkey(introduction_point) # if blinded path introduction point is our direct peer, no need to route-find if peer: # start of blinded path is our peer @@ -250,7 +250,7 @@ def send_onion_message_to( path = create_onion_message_route_to(lnwallet, introduction_point) # first edge must be to our peer - peer = lnwallet.peers.get(path[0].end_node) + peer = lnwallet.lnpeermgr.get_peer_by_pubkey(path[0].end_node) assert peer, 'first hop not a peer' # last edge is to introduction point and start of blinded path. remove from route @@ -321,7 +321,7 @@ def send_onion_message_to( raise Exception('cannot send to myself') hops_data = [] - peer = lnwallet.peers.get(pubkey) + peer = lnwallet.lnpeermgr.get_peer_by_pubkey(pubkey) if peer: # destination is our direct peer, no need to route-find @@ -330,7 +330,7 @@ def send_onion_message_to( path = create_onion_message_route_to(lnwallet, pubkey) # first edge must be to our peer - peer = lnwallet.peers.get(path[0].end_node) + peer = lnwallet.lnpeermgr.get_peer_by_pubkey(path[0].end_node) assert peer, 'first hop not a peer' hops_data = [ @@ -379,9 +379,9 @@ def get_blinded_reply_paths( - reply_path introduction points are direct peers only (TODO: longer reply paths)""" # TODO: build longer paths and/or add dummy hops to increase privacy my_active_channels = [chan for chan in lnwallet.channels.values() if chan.is_active()] - my_onionmsg_channels = [chan for chan in my_active_channels if lnwallet.peers.get(chan.node_id) and - lnwallet.peers.get(chan.node_id).their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT)] - my_onionmsg_peers = [peer for peer in lnwallet.peers.values() if peer.their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT)] + my_onionmsg_channels = [chan for chan in my_active_channels if lnwallet.lnpeermgr.get_peer_by_pubkey(chan.node_id) and + lnwallet.lnpeermgr.get_peer_by_pubkey(chan.node_id).their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT)] + my_onionmsg_peers = [peer for peer in lnwallet.lnpeermgr.peers.values() if peer.their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT)] result = [] mynodeid = lnwallet.node_keypair.pubkey @@ -472,7 +472,7 @@ class OnionMessageManager(Logger): try: onion_packet_b = onion_packet.to_bytes() - next_peer = self.lnwallet.peers.get(node_id) + next_peer = self.lnwallet.lnpeermgr.get_peer_by_pubkey(node_id) if not next_peer.their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT): self.logger.debug('forward dropped, next peer is not ONION_MESSAGE capable') @@ -528,7 +528,7 @@ class OnionMessageManager(Logger): req.future.set_exception(copy.copy(e)) # NOTE: above, when passing the caught exception instance e directly it leads to GeneratorExit() in if isinstance(e, NoRouteFound) and e.peer_address: - await self.lnwallet.add_peer(str(e.peer_address)) + await self.lnwallet.lnpeermgr.add_peer(str(e.peer_address)) else: self.logger.debug(f'resubmit {key=}') self.send_queue.put_nowait((now() + self.REQUEST_REPLY_RETRY_DELAY, expires, key)) @@ -700,7 +700,7 @@ class OnionMessageManager(Logger): 'onion_message dropped (not forwarding due to lightning_forward_payments config option disabled') return # is next_node one of our peers? - next_peer = self.lnwallet.peers.get(next_node_id) + next_peer = self.lnwallet.lnpeermgr.get_peer_by_pubkey(next_node_id) if not next_peer: self.logger.info(f'next node {next_node_id.hex()} not a peer, dropping message') return diff --git a/electrum/scripts/ln_features.py b/electrum/scripts/ln_features.py index 954fdbc56..8e99290a0 100644 --- a/electrum/scripts/ln_features.py +++ b/electrum/scripts/ln_features.py @@ -81,7 +81,7 @@ async def worker(work_queue: asyncio.Queue, results_queue: asyncio.Queue, flag): print(f"worker connecting to {connect_str}") try: - peer = await wallet.lnworker.add_peer(connect_str) + peer = await wallet.lnworker.lnpeermgr.add_peer(connect_str) res = await util.wait_for2(peer.initialized, TIMEOUT) if res: if peer.features & flag == work['features'] & flag: diff --git a/electrum/wallet.py b/electrum/wallet.py index 88002e96a..5071325f1 100644 --- a/electrum/wallet.py +++ b/electrum/wallet.py @@ -3451,7 +3451,7 @@ class Abstract_Wallet(ABC, Logger, EventListener): lightning_has_channels = ( self.lnworker and len([chan for chan in self.lnworker.channels.values() if chan.is_open()]) > 0 ) - lightning_online = self.lnworker and self.lnworker.num_peers() > 0 + lightning_online = self.lnworker and self.lnworker.lnpeermgr.num_peers() > 0 num_sats_can_receive = self.lnworker.num_sats_can_receive() if self.lnworker else 0 can_receive_lightning = self.lnworker and num_sats_can_receive > 0 and amount_sat <= num_sats_can_receive try: @@ -3459,7 +3459,7 @@ class Abstract_Wallet(ABC, Logger, EventListener): except Exception: zeroconf_nodeid = None can_get_zeroconf_channel = (self.lnworker and self.config.ACCEPT_ZEROCONF_CHANNELS - and zeroconf_nodeid in self.lnworker.peers) + and self.lnworker.lnpeermgr.get_peer_by_pubkey(zeroconf_nodeid) is not None) status = self.get_invoice_status(req) if status == PR_EXPIRED: diff --git a/tests/test_commands.py b/tests/test_commands.py index 969bfb024..1acf4c68d 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -760,22 +760,23 @@ class TestCommandsTestnet(ElectrumTestCase): # Mock the network and lnworker mock_lnworker = mock.Mock() + mock_lnworker.lnpeermgr = mock.Mock() w.lnworker = mock_lnworker mock_peer = mock.Mock() mock_peer.initialized = asyncio.Future() connection_string = "test_node_id@127.0.0.1:9735" called = False - async def lnworker_add_peer(*args, **kwargs): + async def lnpeermgr_add_peer(*args, **kwargs): assert args[0] == connection_string nonlocal called called += 1 return mock_peer - mock_lnworker.add_peer = lnworker_add_peer + mock_lnworker.lnpeermgr.add_peer = lnpeermgr_add_peer # check if add_peer times out if peer doesn't initialize (LN_P2P_NETWORK_TIMEOUT is 0.001s) with self.assertRaises(UserFacingException): await cmds.add_peer(connection_string=connection_string, wallet=w) - # check if add_peer called lnworker.add_peer + # check if add_peer called lnpeermgr.add_peer assert called == 1 mock_peer.initialized = asyncio.Future() diff --git a/tests/test_lnchannel.py b/tests/test_lnchannel.py index 149af3748..b55281ca6 100644 --- a/tests/test_lnchannel.py +++ b/tests/test_lnchannel.py @@ -23,6 +23,7 @@ # (around commit 42de4400bff5105352d0552155f73589166d162b). import unittest +from functools import lru_cache from unittest import mock import os import binascii @@ -40,7 +41,7 @@ from electrum.crypto import privkey_to_pubkey from electrum.lnutil import SENT, LOCAL, REMOTE, RECEIVED, UpdateAddHtlc from electrum.lnutil import effective_htlc_tx_weight from electrum.logging import console_stderr_handler -from electrum.lnchannel import ChannelState +from electrum.lnchannel import ChannelState, Channel from electrum.json_db import StoredDict from electrum.coinchooser import PRNG @@ -124,6 +125,7 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator, return StoredDict(state, None) +@lru_cache() def bip32(sequence): node = bip32_utils.BIP32Node.from_rootseed(b"9dk", xtype='standard').subkey_at_private_derivation(sequence) k = node.eckey.get_secret_bytes() @@ -137,7 +139,7 @@ def create_test_channels(*, feerate=6000, local_msat=None, remote_msat=None, alice_pubkey=b"\x01"*33, bob_pubkey=b"\x02"*33, random_seed=None, anchor_outputs=False, local_max_inflight=None, remote_max_inflight=None, - max_accepted_htlcs=5): + max_accepted_htlcs=5) -> tuple[Channel, Channel]: if random_seed is None: # needed for deterministic randomness random_seed = os.urandom(32) random_gen = PRNG(random_seed) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 8b9f49def..7f8f36ce5 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -10,6 +10,7 @@ from collections import defaultdict import logging import concurrent from concurrent import futures +from functools import lru_cache from unittest import mock from typing import Iterable, NamedTuple, Tuple, List, Dict, Sequence from types import MappingProxyType @@ -24,6 +25,7 @@ import electrum.trampoline from electrum import bitcoin from electrum import util from electrum import constants +from electrum import bip32 from electrum.network import Network from electrum import simple_config, lnutil from electrum.lnaddr import lnencode, LnAddr, lndecode @@ -37,7 +39,7 @@ from electrum.lnutil import Keypair, PaymentFailure, LnFeatures, HTLCOwner, Paym from electrum.lnchannel import ChannelState, PeerState, Channel from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent from electrum.channel_db import ChannelDB -from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo, PaySession +from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo, PaySession, LNPeerManager from electrum.lnmsg import encode_msg, decode_msg from electrum import lnmsg from electrum.logging import console_stderr_handler, Logger @@ -49,10 +51,11 @@ from electrum.interface import GracefulDisconnect from electrum.simple_config import SimpleConfig from electrum.fee_policy import FeeTimeEstimates, FEE_ETA_TARGETS from electrum.mpp_split import split_amount_normal +from electrum.wallet import Abstract_Wallet from .test_lnchannel import create_test_channels from .test_bitcoin import needs_test_with_all_chacha20_implementations -from . import ElectrumTestCase +from . import ElectrumTestCase, restore_wallet_from_text__for_unittest def keypair(): @@ -62,9 +65,6 @@ def keypair(): privkey=priv) return k1 -@contextmanager -def noop_lock(): - yield class MockNetwork: def __init__(self, tx_queue, *, config: SimpleConfig): @@ -120,144 +120,100 @@ class MockADB: def get_local_height(self): return self._blockchain.height() -class MockWallet: - receive_requests = {} - adb = MockADB() - - def get_invoice(self, key): - pass - - def get_request(self, key): - pass - - def get_key_for_receive_request(self, x): - pass - - def set_label(self, x, y): - pass - - def save_db(self): - pass - - def is_lightning_backup(self): - return False - - def is_mine(self, addr): - return True - - def get_fingerprint(self): - return '' - - def get_new_sweep_address_for_channel(self): - # note: sweep is not tested here, only in regtest - return "tb1qqu5newtapamjchgxf0nty6geuykhvwas45q4q4" - - def is_up_to_date(self): - return True - class MockLNGossip: def get_sync_progress_estimate(self): return None, None, None -class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): +class MockLNPeerManager(LNPeerManager): + def __init__( + self, + *, + node_keypair, + config: SimpleConfig, + features: LnFeatures, + lnwallet: LNWallet, + network: 'MockNetwork', + ): + LNPeerManager.__init__( + self, + node_keypair=node_keypair, + lnwallet_or_lngossip=lnwallet, + features=features, + config=config, + ) + self.network = network + + +@lru_cache() +def _bip32_from_name(name: str) -> bip32.BIP32Node: + # note: unlike a serialized xprv, the bip32 node can be cached easily, + # as it does not depend on constant.net (testnet/mainnet) network bytes + sequence = [ord(c) for c in name] + bip32_node = bip32.BIP32Node.from_rootseed(b"9dk", xtype='standard').subkey_at_private_derivation(sequence) + return bip32_node + + +class MockLNWallet(LNWallet): MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1 - PAYMENT_TIMEOUT = 120 TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 0 MPP_SPLIT_PART_FRACTION = 1 # this disables the forced splitting - MPP_SPLIT_PART_MINAMT_MSAT = 5_000_000 - def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name, has_anchors): + def __init__(self, *, tx_queue, name, has_anchors, ln_xprv: str = None): self.name = name - Logger.__init__(self) - NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1) - self.node_keypair = local_keypair - self.payment_secret_key = os.urandom(32) # does not need to be deterministic in tests + self._user_dir = tempfile.mkdtemp(prefix="electrum-lnpeer-test-") self.config = SimpleConfig({}, read_user_dir_function=lambda: self._user_dir) - self.network = MockNetwork(tx_queue, config=self.config) - self.taskgroup = OldTaskGroup() + self.config.ENABLE_ANCHOR_CHANNELS = has_anchors + self.config.INITIAL_TRAMPOLINE_FEE_LEVEL = 0 + + network = MockNetwork(tx_queue, config=self.config) + + wallet = restore_wallet_from_text__for_unittest( + "9dk", path=None, passphrase=name, config=self.config)['wallet'] # type: Abstract_Wallet + wallet.is_up_to_date = lambda: True + wallet.adb.network = wallet.network = network + + features = LnFeatures(0) + features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT + features |= LnFeatures.OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT + features |= LnFeatures.VAR_ONION_OPT + features |= LnFeatures.PAYMENT_SECRET_OPT + features |= LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM + features |= LnFeatures.OPTION_CHANNEL_TYPE_OPT + features |= LnFeatures.OPTION_SCID_ALIAS_OPT + features |= LnFeatures.OPTION_STATIC_REMOTEKEY_OPT + + if ln_xprv is None: + ln_xprv = _bip32_from_name(name).to_xprv() + LNWallet.__init__(self, wallet=wallet, xprv=ln_xprv, features=features) + + self.lnpeermgr = MockLNPeerManager( + node_keypair=self.node_keypair, + config=self.config, + features=features, + lnwallet=self, + network=network, + ) self.lnwatcher = None self.swap_manager = None self.onion_message_manager = None self.listen_server = None - self._channels = {chan.channel_id: chan for chan in chans} - self.payment_info = {} - self.logs = defaultdict(list) - self.wallet = MockWallet() - self.features = LnFeatures(0) - self.features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT - self.features |= LnFeatures.OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT - self.features |= LnFeatures.VAR_ONION_OPT - self.features |= LnFeatures.PAYMENT_SECRET_OPT - self.features |= LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM - self.features |= LnFeatures.OPTION_CHANNEL_TYPE_OPT - self.features |= LnFeatures.OPTION_SCID_ALIAS_OPT - self.features |= LnFeatures.OPTION_STATIC_REMOTEKEY_OPT - self.config.ENABLE_ANCHOR_CHANNELS = has_anchors - for chan in chans: - chan.lnworker = self - self._peers = {} # bytes -> Peer - # used in tests - self.enable_htlc_settle = True - self.enable_htlc_forwarding = True - self.received_mpp_htlcs = dict() - self._paysessions = dict() - self.sent_htlcs_info = dict() - self.sent_buckets = defaultdict(set) - self.active_forwardings = {} - self.forwarding_failures = {} - self.inflight_payments = set() - self._preimages = {} - self.stopping_soon = False - self.downstream_to_upstream_htlc = {} - self.dont_expire_htlcs = {} - self.dont_settle_htlcs = {} - self.hold_invoice_callbacks = {} - self._payment_bundles_pkey_to_canon = {} # type: Dict[bytes, bytes] - self._payment_bundles_canon_to_pkeylist = {} # type: Dict[bytes, Sequence[bytes]] - self.config.INITIAL_TRAMPOLINE_FEE_LEVEL = 0 - self._channel_sending_capacity_lock = asyncio.Lock() - self.logger.info(f"created LNWallet[{name}] with nodeID={local_keypair.pubkey.hex()}") + self.logger.info(f"created LNWallet[{name}] with nodeID={self.node_keypair.pubkey.hex()}") - def clear_invoices_cache(self): - pass + def _add_channel(self, chan: Channel): + self._channels[chan.channel_id] = chan + chan.lnworker = self - def get_invoice_status(self, key): - pass - - @property - def lock(self): - return noop_lock() - - @property - def channel_db(self): - return self.network.channel_db if self.network else None - - def uses_trampoline(self): - return not bool(self.channel_db) - - @property - def channels(self): - return self._channels - - @property - def peers(self): - return self._peers - - def get_channel_by_short_id(self, short_channel_id): - with self.lock: - for chan in self._channels.values(): - if chan.short_channel_id == short_channel_id: - return chan - - def channel_state_changed(self, chan): - pass + @LNWallet.features.setter + def features(self, value): + self.lnpeermgr.features = value def save_channel(self, chan): - print("Ignoring channel save") + pass + #print("Ignoring channel save") def diagnostic_name(self): return self.name @@ -290,69 +246,6 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): budget=PaymentFeeBudget.from_invoice_amount(invoice_amount_msat=amount_msat, config=self.config), )] - get_payments = LNWallet.get_payments - get_payment_secret = LNWallet.get_payment_secret - get_payment_info = LNWallet.get_payment_info - save_payment_info = LNWallet.save_payment_info - set_invoice_status = LNWallet.set_invoice_status - set_request_status = LNWallet.set_request_status - set_payment_status = LNWallet.set_payment_status - get_payment_status = LNWallet.get_payment_status - htlc_fulfilled = LNWallet.htlc_fulfilled - htlc_failed = LNWallet.htlc_failed - save_preimage = LNWallet.save_preimage - get_preimage = LNWallet.get_preimage - create_route_for_single_htlc = LNWallet.create_route_for_single_htlc - create_routes_for_payment = LNWallet.create_routes_for_payment - _check_bolt11_invoice = LNWallet._check_bolt11_invoice - pay_to_route = LNWallet.pay_to_route - pay_to_node = LNWallet.pay_to_node - pay_invoice = LNWallet.pay_invoice - force_close_channel = LNWallet.force_close_channel - schedule_force_closing = LNWallet.schedule_force_closing - on_peer_successfully_established = LNWallet.on_peer_successfully_established - get_channel_by_id = LNWallet.get_channel_by_id - channels_for_peer = LNWallet.channels_for_peer - calc_routing_hints_for_invoice = LNWallet.calc_routing_hints_for_invoice - get_channels_for_receiving = LNWallet.get_channels_for_receiving - handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc - is_trampoline_peer = LNWallet.is_trampoline_peer - wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed - #on_event_proxy_set = LNWallet.on_event_proxy_set - _decode_channel_update_msg = LNWallet._decode_channel_update_msg - _handle_chanupd_from_failed_htlc = LNWallet._handle_chanupd_from_failed_htlc - is_forwarded_htlc = LNWallet.is_forwarded_htlc - notify_upstream_peer = LNWallet.notify_upstream_peer - _force_close_channel = LNWallet._force_close_channel - suggest_payment_splits = LNWallet.suggest_payment_splits - register_hold_invoice = LNWallet.register_hold_invoice - unregister_hold_invoice = LNWallet.unregister_hold_invoice - add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice - - update_or_create_mpp_with_received_htlc = LNWallet.update_or_create_mpp_with_received_htlc - set_mpp_resolution = LNWallet.set_mpp_resolution - get_mpp_amounts = LNWallet.get_mpp_amounts - bundle_payments = LNWallet.bundle_payments - get_payment_bundle = LNWallet.get_payment_bundle - _get_payment_key = LNWallet._get_payment_key - save_forwarding_failure = LNWallet.save_forwarding_failure - get_forwarding_failure = LNWallet.get_forwarding_failure - maybe_cleanup_forwarding = LNWallet.maybe_cleanup_forwarding - current_target_feerate_per_kw = LNWallet.current_target_feerate_per_kw - current_low_feerate_per_kw_srk_channel = LNWallet.current_low_feerate_per_kw_srk_channel - create_onion_for_route = LNWallet.create_onion_for_route - maybe_forward_htlc_set = LNWallet.maybe_forward_htlc_set - _maybe_forward_htlc = LNWallet._maybe_forward_htlc - _maybe_forward_trampoline = LNWallet._maybe_forward_trampoline - _maybe_refuse_to_forward_htlc_that_corresponds_to_payreq_we_created = LNWallet._maybe_refuse_to_forward_htlc_that_corresponds_to_payreq_we_created - set_htlc_set_error = LNWallet.set_htlc_set_error - is_payment_bundle_complete = LNWallet.is_payment_bundle_complete - delete_payment_bundle = LNWallet.delete_payment_bundle - _process_htlc_log = LNWallet._process_htlc_log - _get_invoice_features = LNWallet._get_invoice_features - receive_requires_jit_channel = LNWallet.receive_requires_jit_channel - can_get_zeroconf_channel = LNWallet.can_get_zeroconf_channel - class MockTransport: def __init__(self, name): @@ -667,25 +560,24 @@ class TestPeerDirect(TestPeer): def prepare_peers( self, alice_channel: Channel, bob_channel: Channel, - *, k1: Keypair = None, k2: Keypair = None, ): - if k1 is None: - k1 = keypair() - if k2 is None: - k2 = keypair() + q1, q2 = asyncio.Queue(), asyncio.Queue() + w1 = MockLNWallet(tx_queue=q1, name=bob_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS) + w2 = MockLNWallet(tx_queue=q2, name=alice_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS) + k1 = w1.node_keypair + k2 = w2.node_keypair alice_channel.node_id = k2.pubkey bob_channel.node_id = k1.pubkey alice_channel.storage['node_id'] = alice_channel.node_id bob_channel.storage['node_id'] = bob_channel.node_id t1, t2 = transport_pair(k1, k2, alice_channel.name, bob_channel.name) - q1, q2 = asyncio.Queue(), asyncio.Queue() - w1 = MockLNWallet(local_keypair=k1, chans=[alice_channel], tx_queue=q1, name=bob_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS) - w2 = MockLNWallet(local_keypair=k2, chans=[bob_channel], tx_queue=q2, name=alice_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS) + w1._add_channel(alice_channel) + w2._add_channel(bob_channel) self._lnworkers_created.extend([w1, w2]) p1 = PeerInTests(w1, k2.pubkey, t1) p2 = PeerInTests(w2, k1.pubkey, t2) - w1._peers[p1.pubkey] = p1 - w2._peers[p2.pubkey] = p2 + w1.lnpeermgr._peers[p1.pubkey] = p1 + w2.lnpeermgr._peers[p2.pubkey] = p2 # mark_open won't work if state is already OPEN. # so set it to FUNDED alice_channel._state = ChannelState.FUNDED @@ -790,10 +682,9 @@ class TestPeerDirect(TestPeer): ----sig--> """ chan_AB, chan_BA = create_test_channels() - k1, k2 = keypair(), keypair() # note: we don't start peer.htlc_switch() so that the fake htlcs are left alone. async def f(): - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2) + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA) async with OldTaskGroup() as group: await group.spawn(p1._message_loop()) await group.spawn(p2._message_loop()) @@ -807,7 +698,7 @@ class TestPeerDirect(TestPeer): await group.cancel_remaining() # simulating disconnection. recreate transports. self.logger.info("simulating disconnection. recreating transports.") - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2) + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA) for chan in (chan_AB, chan_BA): chan.peer_state = PeerState.DISCONNECTED async with OldTaskGroup() as group: @@ -846,10 +737,9 @@ class TestPeerDirect(TestPeer): ----rev--> """ chan_AB, chan_BA = create_test_channels() - k1, k2 = keypair(), keypair() # note: we don't start peer.htlc_switch() so that the fake htlcs are left alone. async def f(): - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2) + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA) async with OldTaskGroup() as group: await group.spawn(p1._message_loop()) await group.spawn(p2._message_loop()) @@ -864,7 +754,7 @@ class TestPeerDirect(TestPeer): await group.cancel_remaining() # simulating disconnection. recreate transports. self.logger.info("simulating disconnection. recreating transports.") - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2) + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA) for chan in (chan_AB, chan_BA): chan.peer_state = PeerState.DISCONNECTED async with OldTaskGroup() as group: @@ -1788,7 +1678,7 @@ class TestPeerDirect(TestPeer): with self.assertRaises(NoPathFound) as e: await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr) - peer = w1.peers[route[0].node_id] + peer = w1.lnpeermgr._peers[route[0].node_id] # AssertionError is ok since we shouldn't use old routes, and the # route finding should fail when channel is closed async def f(): @@ -2126,12 +2016,19 @@ class TestPeerDirect(TestPeer): class TestPeerForwarding(TestPeer): def prepare_chans_and_peers_in_graph(self, graph_definition) -> Graph: - keys = {k: keypair() for k in graph_definition} + workers = {} # type: Dict[str, MockLNWallet] txs_queues = {k: asyncio.Queue() for k in graph_definition} + + # create workers + for a, definition in graph_definition.items(): + workers[a] = MockLNWallet(tx_queue=txs_queues[a], name=a, has_anchors=self.TEST_ANCHOR_CHANNELS) + self._lnworkers_created.extend(list(workers.values())) + keys = {name: w.node_keypair for name, w in workers.items()} + channels = {} # type: Dict[Tuple[str, str], Channel] transports = {} - workers = {} # type: Dict[str, MockLNWallet] peers = {} + # create channels for a, definition in graph_definition.items(): for b, channel_def in definition.get('channels', {}).items(): @@ -2145,6 +2042,8 @@ class TestPeerForwarding(TestPeer): anchor_outputs=self.TEST_ANCHOR_CHANNELS ) channels[(a, b)], channels[(b, a)] = channel_ab, channel_ba + workers[a]._add_channel(channel_ab) + workers[b]._add_channel(channel_ba) transport_ab, transport_ba = transport_pair(keys[a], keys[b], channel_ab.name, channel_ba.name) transports[(a, b)], transports[(b, a)] = transport_ab, transport_ba # set fees @@ -2153,12 +2052,6 @@ class TestPeerForwarding(TestPeer): channel_ba.forwarding_fee_proportional_millionths = channel_def['remote_fee_rate_millionths'] channel_ba.forwarding_fee_base_msat = channel_def['remote_base_fee_msat'] - # create workers and peers - for a, definition in graph_definition.items(): - channels_of_node = [c for k, c in channels.items() if k[0] == a] - workers[a] = MockLNWallet(local_keypair=keys[a], chans=channels_of_node, tx_queue=txs_queues[a], name=a, has_anchors=self.TEST_ANCHOR_CHANNELS) - self._lnworkers_created.extend(list(workers.values())) - # create peers for ab in channels.keys(): peers[ab] = Peer(workers[ab[0]], keys[ab[1]].pubkey, transports[ab]) @@ -2167,7 +2060,7 @@ class TestPeerForwarding(TestPeer): for a, w in workers.items(): for ab, peer_ab in peers.items(): if ab[0] == a: - w._peers[peer_ab.pubkey] = peer_ab + w.lnpeermgr._peers[peer_ab.pubkey] = peer_ab # set forwarding properties for a, definition in graph_definition.items(): diff --git a/tests/test_onion_message.py b/tests/test_onion_message.py index 76e09a2c8..ce23b98d6 100644 --- a/tests/test_onion_message.py +++ b/tests/test_onion_message.py @@ -352,9 +352,8 @@ class TestOnionMessageManager(ElectrumTestCase): async def test_request_and_reply(self): n = MockNetwork() - k = keypair() q1, q2 = asyncio.Queue(), asyncio.Queue() - lnw = MockLNWallet(local_keypair=k, chans=[], tx_queue=q1, name='test_request_and_reply', has_anchors=False) + lnw = MockLNWallet(tx_queue=q1, name='test_request_and_reply', has_anchors=False) def slow(*args, **kwargs): time.sleep(2*TIME_STEP) @@ -369,10 +368,10 @@ class TestOnionMessageManager(ElectrumTestCase): rkey1 = bfh('0102030405060708') rkey2 = bfh('0102030405060709') - lnw.peers[self.alice.pubkey] = MockPeer(self.alice.pubkey) - lnw.peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=slow) - lnw.peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(withreply, rkey1)) - lnw.peers[self.dave.pubkey] = MockPeer(self.dave.pubkey, on_send_message=partial(slowwithreply, rkey2)) + lnw.lnpeermgr._peers[self.alice.pubkey] = MockPeer(self.alice.pubkey) + lnw.lnpeermgr._peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=slow) + lnw.lnpeermgr._peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(withreply, rkey1)) + lnw.lnpeermgr._peers[self.dave.pubkey] = MockPeer(self.dave.pubkey, on_send_message=partial(slowwithreply, rkey2)) t = OnionMessageManager(lnw) t.start_network(network=n) @@ -401,7 +400,8 @@ class TestOnionMessageManager(ElectrumTestCase): async def test_forward(self): n = MockNetwork() q1 = asyncio.Queue() - lnw = MockLNWallet(local_keypair=self.alice, chans=[], tx_queue=q1, name='alice', has_anchors=False) + lnw = MockLNWallet(tx_queue=q1, name='alice', has_anchors=False) + lnw.node_keypair = self.alice self.was_sent = False @@ -414,8 +414,8 @@ class TestOnionMessageManager(ElectrumTestCase): self.assertEqual(message_type, 'onion_message') self.assertEqual(payload['onion_message_packet'], kwargs['onion_message_packet']) - lnw.peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=partial(on_send, 'bob')) - lnw.peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(on_send, 'carol')) + lnw.lnpeermgr._peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=partial(on_send, 'bob')) + lnw.lnpeermgr._peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(on_send, 'carol')) t = OnionMessageManager(lnw) t.start_network(network=n) @@ -438,7 +438,8 @@ class TestOnionMessageManager(ElectrumTestCase): async def test_receive_unsolicited(self): n = MockNetwork() q1 = asyncio.Queue() - lnw = MockLNWallet(local_keypair=self.dave, chans=[], tx_queue=q1, name='dave', has_anchors=False) + lnw = MockLNWallet(tx_queue=q1, name='dave', has_anchors=False) + lnw.node_keypair = self.dave t = OnionMessageManager(lnw) t.start_network(network=n)