Refactor trampoline forwarding and hold invoices.
- maybe_fulfill_htlc returns a forwarding callback that covers both cases. - previously, the callback of hold invoices was called as a side-effect of lnworker.check_mpp_status. - the same data structures (lnworker.trampoline_forwardings, lnworker.trampoline_forwarding_errors) are used for both trampoline forwardings and hold invoices. - maybe_fulfill_htlc still recursively calls itself to perform checks on trampoline onion. This is ugly, but ugliness is now contained to that method.
This commit is contained in:
@@ -9,7 +9,7 @@ from collections import OrderedDict, defaultdict
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set
|
||||
from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set, Callable
|
||||
from datetime import datetime
|
||||
import functools
|
||||
|
||||
@@ -1668,7 +1668,8 @@ class Peer(Logger):
|
||||
next_peer.maybe_send_commitment(next_chan)
|
||||
return next_chan_scid, next_htlc.htlc_id
|
||||
|
||||
def maybe_forward_trampoline(
|
||||
@log_exceptions
|
||||
async def maybe_forward_trampoline(
|
||||
self, *,
|
||||
payment_hash: bytes,
|
||||
cltv_expiry: int,
|
||||
@@ -1713,48 +1714,34 @@ class Peer(Logger):
|
||||
trampoline_fee = total_msat - amt_to_forward
|
||||
self.logger.info(f'trampoline cltv and fee: {trampoline_cltv_delta, trampoline_fee}')
|
||||
|
||||
@log_exceptions
|
||||
async def forward_trampoline_payment():
|
||||
try:
|
||||
await self.lnworker.pay_to_node(
|
||||
node_pubkey=outgoing_node_id,
|
||||
payment_hash=payment_hash,
|
||||
payment_secret=payment_secret,
|
||||
amount_to_pay=amt_to_forward,
|
||||
min_cltv_expiry=cltv_from_onion,
|
||||
r_tags=[],
|
||||
invoice_features=invoice_features,
|
||||
fwd_trampoline_onion=next_trampoline_onion,
|
||||
fwd_trampoline_fee=trampoline_fee,
|
||||
fwd_trampoline_cltv_delta=trampoline_cltv_delta,
|
||||
attempts=1)
|
||||
except OnionRoutingFailure as e:
|
||||
# FIXME: cannot use payment_hash as key
|
||||
self.lnworker.trampoline_forwarding_failures[payment_hash] = e
|
||||
except PaymentFailure as e:
|
||||
# FIXME: adapt the error code
|
||||
error_reason = OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')
|
||||
self.lnworker.trampoline_forwarding_failures[payment_hash] = error_reason
|
||||
|
||||
# remove from list of payments, so that another attempt can be initiated
|
||||
self.lnworker.trampoline_forwardings.remove(payment_hash)
|
||||
|
||||
# add to list of ongoing payments
|
||||
self.lnworker.trampoline_forwardings.add(payment_hash)
|
||||
# clear previous failures
|
||||
self.lnworker.trampoline_forwarding_failures.pop(payment_hash, None)
|
||||
# start payment
|
||||
asyncio.ensure_future(forward_trampoline_payment())
|
||||
try:
|
||||
await self.lnworker.pay_to_node(
|
||||
node_pubkey=outgoing_node_id,
|
||||
payment_hash=payment_hash,
|
||||
payment_secret=payment_secret,
|
||||
amount_to_pay=amt_to_forward,
|
||||
min_cltv_expiry=cltv_from_onion,
|
||||
r_tags=[],
|
||||
invoice_features=invoice_features,
|
||||
fwd_trampoline_onion=next_trampoline_onion,
|
||||
fwd_trampoline_fee=trampoline_fee,
|
||||
fwd_trampoline_cltv_delta=trampoline_cltv_delta,
|
||||
attempts=1)
|
||||
except OnionRoutingFailure as e:
|
||||
raise
|
||||
except PaymentFailure as e:
|
||||
# FIXME: adapt the error code
|
||||
raise OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')
|
||||
|
||||
def maybe_fulfill_htlc(
|
||||
self, *,
|
||||
chan: Channel,
|
||||
htlc: UpdateAddHtlc,
|
||||
processed_onion: ProcessedOnionPacket,
|
||||
is_trampoline: bool = False) -> Optional[bytes]:
|
||||
|
||||
onion_packet_bytes: bytes,
|
||||
is_trampoline: bool = False) -> Tuple[Optional[bytes], Optional[Callable]]:
|
||||
"""As a final recipient of an HTLC, decide if we should fulfill it.
|
||||
Return preimage or None
|
||||
Return (preimage, forwarding_callback) with at most a single element not None
|
||||
"""
|
||||
def log_fail_reason(reason: str):
|
||||
self.logger.info(f"maybe_fulfill_htlc. will FAIL HTLC: chan {chan.short_channel_id}. "
|
||||
@@ -1810,19 +1797,55 @@ class Peer(Logger):
|
||||
log_fail_reason(f"'payment_secret' missing from onion")
|
||||
raise exc_incorrect_or_unknown_pd
|
||||
|
||||
payment_status = self.lnworker.check_received_htlc(payment_secret_from_onion, chan.short_channel_id, htlc, total_msat)
|
||||
payment_status = self.lnworker.check_mpp_status(payment_secret_from_onion, chan.short_channel_id, htlc, total_msat)
|
||||
if payment_status is None:
|
||||
return None
|
||||
return None, None
|
||||
elif payment_status is False:
|
||||
log_fail_reason(f"MPP_TIMEOUT")
|
||||
raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'')
|
||||
else:
|
||||
assert payment_status is True
|
||||
|
||||
payment_hash = htlc.payment_hash
|
||||
preimage = self.lnworker.get_preimage(payment_hash)
|
||||
hold_invoice_callback = self.lnworker.hold_invoice_callbacks.get(payment_hash)
|
||||
if not preimage and hold_invoice_callback:
|
||||
if preimage:
|
||||
return preimage, None
|
||||
else:
|
||||
# for hold invoices, trigger callback
|
||||
cb, timeout = hold_invoice_callback
|
||||
if int(time.time()) < timeout:
|
||||
return None, lambda: cb(payment_hash)
|
||||
else:
|
||||
raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'')
|
||||
|
||||
# if there is a trampoline_onion, maybe_fulfill_htlc will be called again
|
||||
if processed_onion.trampoline_onion_packet:
|
||||
# TODO: we should check that all trampoline_onions are the same
|
||||
return None
|
||||
|
||||
trampoline_onion = self.process_onion_packet(
|
||||
processed_onion.trampoline_onion_packet,
|
||||
payment_hash=payment_hash,
|
||||
onion_packet_bytes=onion_packet_bytes,
|
||||
is_trampoline=True)
|
||||
if trampoline_onion.are_we_final:
|
||||
# trampoline- we are final recipient of HTLC
|
||||
preimage, cb = self.maybe_fulfill_htlc(
|
||||
chan=chan,
|
||||
htlc=htlc,
|
||||
processed_onion=trampoline_onion,
|
||||
onion_packet_bytes=onion_packet_bytes,
|
||||
is_trampoline=True)
|
||||
assert cb is None
|
||||
return preimage, None
|
||||
else:
|
||||
callback = lambda: self.maybe_forward_trampoline(
|
||||
payment_hash=payment_hash,
|
||||
cltv_expiry=htlc.cltv_expiry, # TODO: use max or enforce same value across mpp parts
|
||||
outer_onion=processed_onion,
|
||||
trampoline_onion=trampoline_onion)
|
||||
return None, callback
|
||||
|
||||
# TODO don't accept payments twice for same invoice
|
||||
# TODO check invoice expiry
|
||||
@@ -1845,7 +1868,7 @@ class Peer(Logger):
|
||||
if preimage:
|
||||
self.logger.info(f"maybe_fulfill_htlc. will FULFILL HTLC: chan {chan.short_channel_id}. htlc={str(htlc)}")
|
||||
self.lnworker.set_request_status(htlc.payment_hash, PR_PAID)
|
||||
return preimage
|
||||
return preimage, None
|
||||
|
||||
def fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes):
|
||||
self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
|
||||
@@ -2340,42 +2363,36 @@ class Peer(Logger):
|
||||
onion_packet_bytes=onion_packet_bytes)
|
||||
if processed_onion.are_we_final:
|
||||
# either we are final recipient; or if trampoline, see cases below
|
||||
preimage = self.maybe_fulfill_htlc(
|
||||
preimage, forwarding_callback = self.maybe_fulfill_htlc(
|
||||
chan=chan,
|
||||
htlc=htlc,
|
||||
processed_onion=processed_onion)
|
||||
processed_onion=processed_onion,
|
||||
onion_packet_bytes=onion_packet_bytes)
|
||||
|
||||
if processed_onion.trampoline_onion_packet:
|
||||
# trampoline- recipient or forwarding
|
||||
if forwarding_callback:
|
||||
if not forwarding_info:
|
||||
trampoline_onion = self.process_onion_packet(
|
||||
processed_onion.trampoline_onion_packet,
|
||||
payment_hash=payment_hash,
|
||||
onion_packet_bytes=onion_packet_bytes,
|
||||
is_trampoline=True)
|
||||
if trampoline_onion.are_we_final:
|
||||
# trampoline- we are final recipient of HTLC
|
||||
preimage = self.maybe_fulfill_htlc(
|
||||
chan=chan,
|
||||
htlc=htlc,
|
||||
processed_onion=trampoline_onion,
|
||||
is_trampoline=True)
|
||||
# trampoline- HTLC we are supposed to forward, but haven't forwarded yet
|
||||
if not self.lnworker.enable_htlc_forwarding:
|
||||
pass
|
||||
elif payment_hash in self.lnworker.trampoline_forwardings:
|
||||
# we are already forwarding this payment
|
||||
self.logger.info(f"we are already forwarding this.")
|
||||
else:
|
||||
# trampoline- HTLC we are supposed to forward, but haven't forwarded yet
|
||||
if not self.lnworker.enable_htlc_forwarding:
|
||||
return None, None, None
|
||||
|
||||
if payment_hash in self.lnworker.trampoline_forwardings:
|
||||
self.logger.info(f"we are already forwarding this.")
|
||||
# we are already forwarding this payment
|
||||
return None, True, None
|
||||
|
||||
self.maybe_forward_trampoline(
|
||||
payment_hash=payment_hash,
|
||||
cltv_expiry=htlc.cltv_expiry, # TODO: use max or enforce same value across mpp parts
|
||||
outer_onion=processed_onion,
|
||||
trampoline_onion=trampoline_onion)
|
||||
# return True so that this code gets executed only once
|
||||
# add to list of ongoing payments
|
||||
self.lnworker.trampoline_forwardings.add(payment_hash)
|
||||
# clear previous failures
|
||||
self.lnworker.trampoline_forwarding_failures.pop(payment_hash, None)
|
||||
async def wrapped_callback():
|
||||
forwarding_coro = forwarding_callback()
|
||||
try:
|
||||
await forwarding_coro
|
||||
except Exception as e:
|
||||
# FIXME: cannot use payment_hash as key
|
||||
self.lnworker.trampoline_forwarding_failures[payment_hash] = e
|
||||
finally:
|
||||
# remove from list of payments, so that another attempt can be initiated
|
||||
self.lnworker.trampoline_forwardings.remove(payment_hash)
|
||||
asyncio.ensure_future(wrapped_callback())
|
||||
return None, True, None
|
||||
else:
|
||||
# trampoline- HTLC we are supposed to forward, and have already forwarded
|
||||
|
||||
@@ -1922,16 +1922,16 @@ class LNWallet(LNWorker):
|
||||
if write_to_disk:
|
||||
self.wallet.save_db()
|
||||
|
||||
def check_received_htlc(
|
||||
self, payment_secret: bytes,
|
||||
short_channel_id: ShortChannelID,
|
||||
htlc: UpdateAddHtlc,
|
||||
expected_msat: int,
|
||||
|
||||
def check_mpp_status(
|
||||
self, payment_secret: bytes,
|
||||
short_channel_id: ShortChannelID,
|
||||
htlc: UpdateAddHtlc,
|
||||
expected_msat: int,
|
||||
) -> Optional[bool]:
|
||||
""" return MPP status: True (accepted), False (expired) or None (waiting)
|
||||
"""
|
||||
payment_hash = htlc.payment_hash
|
||||
|
||||
self.update_mpp_with_received_htlc(payment_secret, short_channel_id, htlc, expected_msat)
|
||||
is_expired, is_accepted = self.get_mpp_status(payment_secret)
|
||||
if not is_accepted and not is_expired:
|
||||
@@ -1944,19 +1944,7 @@ class LNWallet(LNWorker):
|
||||
elif self.stopping_soon:
|
||||
is_expired = True # try to time out pending HTLCs before shutting down
|
||||
elif all([self.is_mpp_amount_reached(x) for x in payment_secrets]):
|
||||
preimage = self.get_preimage(payment_hash)
|
||||
hold_invoice_callback = self.hold_invoice_callbacks.get(payment_hash)
|
||||
if not preimage and hold_invoice_callback:
|
||||
# for hold invoices, trigger callback
|
||||
cb, timeout = hold_invoice_callback
|
||||
if int(time.time()) < timeout:
|
||||
cb(payment_hash)
|
||||
else:
|
||||
is_expired = True
|
||||
else:
|
||||
# note: preimage will be None for outer trampoline onion
|
||||
is_accepted = True
|
||||
|
||||
is_accepted = True
|
||||
elif time.time() - first_timestamp > self.MPP_EXPIRY:
|
||||
is_expired = True
|
||||
|
||||
|
||||
@@ -251,7 +251,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
set_request_status = LNWallet.set_request_status
|
||||
set_payment_status = LNWallet.set_payment_status
|
||||
get_payment_status = LNWallet.get_payment_status
|
||||
check_received_htlc = LNWallet.check_received_htlc
|
||||
check_mpp_status = LNWallet.check_mpp_status
|
||||
htlc_fulfilled = LNWallet.htlc_fulfilled
|
||||
htlc_failed = LNWallet.htlc_failed
|
||||
save_preimage = LNWallet.save_preimage
|
||||
@@ -764,7 +764,7 @@ class TestPeer(ElectrumTestCase):
|
||||
if test_hold_invoice:
|
||||
payment_hash = lnaddr.paymenthash
|
||||
preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex()))
|
||||
def cb(payment_hash):
|
||||
async def cb(payment_hash):
|
||||
if not test_hold_timeout:
|
||||
w2.save_preimage(payment_hash, preimage)
|
||||
timeout = 1 if test_hold_timeout else 60
|
||||
|
||||
Reference in New Issue
Block a user