lnpeer: more forwarding is now event-driven
This should make unit tests less reliant on sleeps.
This commit is contained in:
@@ -112,6 +112,7 @@ class Peer(Logger):
|
||||
self._htlc_switch_iterstart_event = asyncio.Event()
|
||||
self._htlc_switch_iterdone_event = asyncio.Event()
|
||||
self._received_revack_event = asyncio.Event()
|
||||
self.downstream_htlc_resolved_event = asyncio.Event()
|
||||
|
||||
def send_message(self, message_name: str, **kwargs):
|
||||
assert type(message_name) is str
|
||||
@@ -1198,16 +1199,17 @@ class Peer(Logger):
|
||||
chan.receive_fail_htlc(htlc_id, error_bytes=reason) # TODO handle exc and maybe fail channel (e.g. bad htlc_id)
|
||||
self.maybe_send_commitment(chan)
|
||||
|
||||
def maybe_send_commitment(self, chan: Channel):
|
||||
def maybe_send_commitment(self, chan: Channel) -> bool:
|
||||
# REMOTE should revoke first before we can sign a new ctx
|
||||
if chan.hm.is_revack_pending(REMOTE):
|
||||
return
|
||||
return False
|
||||
# if there are no changes, we will not (and must not) send a new commitment
|
||||
if not chan.has_pending_changes(REMOTE):
|
||||
return
|
||||
return False
|
||||
self.logger.info(f'send_commitment. chan {chan.short_channel_id}. ctn: {chan.get_next_ctn(REMOTE)}.')
|
||||
sig_64, htlc_sigs = chan.sign_next_commitment()
|
||||
self.send_message("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=len(htlc_sigs), htlc_signature=b"".join(htlc_sigs))
|
||||
return True
|
||||
|
||||
def pay(self, *,
|
||||
route: 'LNPaymentRoute',
|
||||
@@ -1424,6 +1426,7 @@ class Peer(Logger):
|
||||
except BaseException as e:
|
||||
self.logger.info(f"failed to forward htlc: error sending message. {e}")
|
||||
raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message)
|
||||
next_peer.maybe_send_commitment(next_chan)
|
||||
return next_chan_scid, next_htlc.htlc_id
|
||||
|
||||
def maybe_forward_trampoline(
|
||||
@@ -1845,11 +1848,14 @@ class Peer(Logger):
|
||||
self._htlc_switch_iterdone_event.set()
|
||||
self._htlc_switch_iterdone_event.clear()
|
||||
# We poll every 0.1 sec to check if there is work to do,
|
||||
# or we can be woken up when receiving a revack.
|
||||
# TODO when forwarding, we should also be woken up when there are
|
||||
# certain events with the downstream peer
|
||||
# or we can also be triggered via events.
|
||||
# When forwarding an HTLC originating from this peer (the upstream),
|
||||
# we can get triggered for events that happen on the downstream peer.
|
||||
# TODO: trampoline forwarding relies on the polling
|
||||
async with ignore_after(0.1):
|
||||
await self._received_revack_event.wait()
|
||||
async with TaskGroup(wait=any) as group:
|
||||
await group.spawn(self._received_revack_event.wait())
|
||||
await group.spawn(self.downstream_htlc_resolved_event.wait())
|
||||
self._htlc_switch_iterstart_event.set()
|
||||
self._htlc_switch_iterstart_event.clear()
|
||||
self.ping_if_required()
|
||||
@@ -1861,6 +1867,8 @@ class Peer(Logger):
|
||||
done = set()
|
||||
unfulfilled = chan.unfulfilled_htlcs
|
||||
for htlc_id, (local_ctn, remote_ctn, onion_packet_hex, forwarding_info) in unfulfilled.items():
|
||||
if forwarding_info:
|
||||
self.lnworker.downstream_htlc_to_upstream_peer_map[forwarding_info] = self.pubkey
|
||||
if not chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id):
|
||||
continue
|
||||
htlc = chan.hm.get_htlc_by_id(REMOTE, htlc_id)
|
||||
@@ -1886,6 +1894,7 @@ class Peer(Logger):
|
||||
error_bytes = construct_onion_error(e, onion_packet, our_onion_private_key=self.privkey)
|
||||
if fw_info:
|
||||
unfulfilled[htlc_id] = local_ctn, remote_ctn, onion_packet_hex, fw_info
|
||||
self.lnworker.downstream_htlc_to_upstream_peer_map[fw_info] = self.pubkey
|
||||
elif preimage or error_reason or error_bytes:
|
||||
if preimage:
|
||||
if not self.lnworker.enable_htlc_settle:
|
||||
@@ -1904,7 +1913,10 @@ class Peer(Logger):
|
||||
done.add(htlc_id)
|
||||
# cleanup
|
||||
for htlc_id in done:
|
||||
unfulfilled.pop(htlc_id)
|
||||
local_ctn, remote_ctn, onion_packet_hex, forwarding_info = unfulfilled.pop(htlc_id)
|
||||
if forwarding_info:
|
||||
self.lnworker.downstream_htlc_to_upstream_peer_map.pop(forwarding_info, None)
|
||||
self.maybe_send_commitment(chan)
|
||||
|
||||
def _maybe_cleanup_received_htlcs_pending_removal(self) -> None:
|
||||
done = set()
|
||||
|
||||
@@ -647,6 +647,8 @@ class LNWallet(LNWorker):
|
||||
self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT)
|
||||
|
||||
self.trampoline_forwarding_failures = {} # todo: should be persisted
|
||||
# map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys
|
||||
self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes]
|
||||
|
||||
def has_deterministic_node_id(self):
|
||||
return bool(self.db.get('lightning_xprv'))
|
||||
@@ -1847,8 +1849,23 @@ class LNWallet(LNWorker):
|
||||
info = info._replace(status=status)
|
||||
self.save_payment_info(info)
|
||||
|
||||
def _on_maybe_forwarded_htlc_resolved(self, chan: Channel, htlc_id: int) -> None:
|
||||
"""Called when an HTLC we offered on chan gets irrevocably fulfilled or failed.
|
||||
If we find this was a forwarded HTLC, the upstream peer is notified.
|
||||
"""
|
||||
fw_info = chan.short_channel_id.hex(), htlc_id
|
||||
upstream_peer_pubkey = self.downstream_htlc_to_upstream_peer_map.get(fw_info)
|
||||
if not upstream_peer_pubkey:
|
||||
return
|
||||
upstream_peer = self.peers.get(upstream_peer_pubkey)
|
||||
if not upstream_peer:
|
||||
return
|
||||
upstream_peer.downstream_htlc_resolved_event.set()
|
||||
upstream_peer.downstream_htlc_resolved_event.clear()
|
||||
|
||||
def htlc_fulfilled(self, chan: Channel, payment_hash: bytes, htlc_id: int):
|
||||
util.trigger_callback('htlc_fulfilled', payment_hash, chan, htlc_id)
|
||||
self._on_maybe_forwarded_htlc_resolved(chan=chan, htlc_id=htlc_id)
|
||||
q = self.sent_htlcs.get(payment_hash)
|
||||
if q:
|
||||
route, payment_secret, amount_msat, bucket_msat, amount_receiver_msat = self.sent_htlcs_routes[(payment_hash, chan.short_channel_id, htlc_id)]
|
||||
@@ -1871,6 +1888,7 @@ class LNWallet(LNWorker):
|
||||
failure_message: Optional['OnionRoutingFailure']):
|
||||
|
||||
util.trigger_callback('htlc_failed', payment_hash, chan, htlc_id)
|
||||
self._on_maybe_forwarded_htlc_resolved(chan=chan, htlc_id=htlc_id)
|
||||
q = self.sent_htlcs.get(payment_hash)
|
||||
if q:
|
||||
# detect if it is part of a bucket
|
||||
|
||||
@@ -153,6 +153,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
self.inflight_payments = set()
|
||||
self.preimages = {}
|
||||
self.stopping_soon = False
|
||||
self.downstream_htlc_to_upstream_peer_map = {}
|
||||
|
||||
self.logger.info(f"created LNWallet[{name}] with nodeID={local_keypair.pubkey.hex()}")
|
||||
|
||||
@@ -241,6 +242,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
on_proxy_changed = LNWallet.on_proxy_changed
|
||||
_decode_channel_update_msg = LNWallet._decode_channel_update_msg
|
||||
_handle_chanupd_from_failed_htlc = LNWallet._handle_chanupd_from_failed_htlc
|
||||
_on_maybe_forwarded_htlc_resolved = LNWallet._on_maybe_forwarded_htlc_resolved
|
||||
|
||||
|
||||
class MockTransport:
|
||||
|
||||
Reference in New Issue
Block a user