lnpeer and lnworker cleanup:
- rename trampoline_forwardings -> final_onion_forwardings, because this dict is used for both trampoline and hold invoices - remove timeout from hold_invoice_callbacks (redundant with invoice) - add test_failure boolean parameter to TestPeer._test_simple_payment, in order to test correct propagation of OnionRoutingFailures. - maybe_fulfill_htlc: raise an OnionRoutingFailure if we do not have the preimage for a payment that does not have a hold invoice callback. Without this, the above unit tests stall when we use test_failure=True
This commit is contained in:
@@ -1887,12 +1887,7 @@ class Peer(Logger):
|
||||
if preimage:
|
||||
return preimage, None
|
||||
else:
|
||||
# for hold invoices, trigger callback
|
||||
cb, timeout = hold_invoice_callback
|
||||
if int(time.time()) < timeout:
|
||||
return None, lambda: cb(payment_hash)
|
||||
else:
|
||||
raise exc_incorrect_or_unknown_pd
|
||||
return None, lambda: hold_invoice_callback(payment_hash)
|
||||
|
||||
# TODO don't accept payments twice for same invoice
|
||||
# TODO check invoice expiry
|
||||
@@ -1903,8 +1898,8 @@ class Peer(Logger):
|
||||
|
||||
preimage = self.lnworker.get_preimage(payment_hash)
|
||||
if not preimage:
|
||||
self.logger.info(f"missing callback {payment_hash.hex()}")
|
||||
return None, None
|
||||
self.logger.info(f"missing preimage and no hold invoice callback {payment_hash.hex()}")
|
||||
raise exc_incorrect_or_unknown_pd
|
||||
|
||||
expected_payment_secrets = [self.lnworker.get_payment_secret(htlc.payment_hash)]
|
||||
expected_payment_secrets.append(derive_payment_secret_from_payment_preimage(preimage)) # legacy secret for old invoices
|
||||
@@ -2424,23 +2419,23 @@ class Peer(Logger):
|
||||
# trampoline- HTLC we are supposed to forward, but haven't forwarded yet
|
||||
if not self.lnworker.enable_htlc_forwarding:
|
||||
pass
|
||||
elif payment_key in self.lnworker.trampoline_forwardings:
|
||||
elif payment_key in self.lnworker.final_onion_forwardings:
|
||||
# we are already forwarding this payment
|
||||
self.logger.info(f"we are already forwarding this.")
|
||||
else:
|
||||
# add to list of ongoing payments
|
||||
self.lnworker.trampoline_forwardings.add(payment_key)
|
||||
self.lnworker.final_onion_forwardings.add(payment_key)
|
||||
# clear previous failures
|
||||
self.lnworker.trampoline_forwarding_failures.pop(payment_key, None)
|
||||
self.lnworker.final_onion_forwarding_failures.pop(payment_key, None)
|
||||
async def wrapped_callback():
|
||||
forwarding_coro = forwarding_callback()
|
||||
try:
|
||||
await forwarding_coro
|
||||
except OnionRoutingFailure as e:
|
||||
self.lnworker.trampoline_forwarding_failures[payment_key] = e
|
||||
self.lnworker.final_onion_forwarding_failures[payment_key] = e
|
||||
finally:
|
||||
# remove from list of payments, so that another attempt can be initiated
|
||||
self.lnworker.trampoline_forwardings.remove(payment_key)
|
||||
self.lnworker.final_onion_forwardings.remove(payment_key)
|
||||
asyncio.ensure_future(wrapped_callback())
|
||||
fw_info = payment_key.hex()
|
||||
return None, fw_info, None
|
||||
@@ -2449,7 +2444,7 @@ class Peer(Logger):
|
||||
payment_key = bytes.fromhex(forwarding_info)
|
||||
preimage = self.lnworker.get_preimage(payment_hash)
|
||||
# get (and not pop) failure because the incoming payment might be multi-part
|
||||
error_reason = self.lnworker.trampoline_forwarding_failures.get(payment_key)
|
||||
error_reason = self.lnworker.final_onion_forwarding_failures.get(payment_key)
|
||||
if error_reason:
|
||||
self.logger.info(f'trampoline forwarding failure: {error_reason.code_name()}')
|
||||
raise error_reason
|
||||
|
||||
@@ -703,8 +703,8 @@ class LNWallet(LNWorker):
|
||||
for payment_hash in self.get_payments(status='inflight').keys():
|
||||
self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT)
|
||||
|
||||
self.trampoline_forwardings = set()
|
||||
self.trampoline_forwarding_failures = {} # todo: should be persisted
|
||||
self.final_onion_forwardings = set()
|
||||
self.final_onion_forwarding_failures = {} # todo: should be persisted
|
||||
# map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys
|
||||
self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes]
|
||||
# payment_hash -> callback, timeout:
|
||||
@@ -1954,11 +1954,8 @@ class LNWallet(LNWorker):
|
||||
info = PaymentInfo(payment_hash, lightning_amount_sat * 1000, RECEIVED, PR_UNPAID)
|
||||
self.save_payment_info(info, write_to_disk=False)
|
||||
|
||||
def register_callback_for_hold_invoice(
|
||||
self, payment_hash: bytes, cb: Callable[[bytes], None], timeout: int,
|
||||
):
|
||||
expiry = int(time.time()) + timeout
|
||||
self.hold_invoice_callbacks[payment_hash] = cb, expiry
|
||||
def register_callback_for_hold_invoice(self, payment_hash: bytes, cb: Callable[[bytes], None]):
|
||||
self.hold_invoice_callbacks[payment_hash] = cb
|
||||
|
||||
def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None:
|
||||
key = info.payment_hash.hex()
|
||||
@@ -2758,7 +2755,7 @@ class LNWallet(LNWorker):
|
||||
util.trigger_callback('channels_updated', self.wallet)
|
||||
self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address())
|
||||
|
||||
def fail_trampoline_forwarding(self, payment_key):
|
||||
def fail_final_onion_forwarding(self, payment_key):
|
||||
""" use this to fail htlcs received for hold invoices"""
|
||||
e = OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')
|
||||
self.trampoline_forwarding_failures[payment_key] = e
|
||||
self.final_onion_forwarding_failures[payment_key] = e
|
||||
|
||||
@@ -247,7 +247,7 @@ class SwapManager(Logger):
|
||||
self.logger.info(f'found confirmed refund')
|
||||
payment_secret = self.lnworker.get_payment_secret(swap.payment_hash)
|
||||
payment_key = swap.payment_hash + payment_secret
|
||||
self.lnworker.fail_trampoline_forwarding(payment_key)
|
||||
self.lnworker.fail_final_onion_forwarding(payment_key)
|
||||
|
||||
if delta < 0:
|
||||
# too early for refund
|
||||
@@ -343,7 +343,7 @@ class SwapManager(Logger):
|
||||
)
|
||||
# add payment info to lnworker
|
||||
self.lnworker.add_payment_info_for_hold_invoice(payment_hash, main_amount_sat)
|
||||
self.lnworker.register_callback_for_hold_invoice(payment_hash, self.hold_invoice_callback, 60*60*24)
|
||||
self.lnworker.register_callback_for_hold_invoice(payment_hash, self.hold_invoice_callback)
|
||||
prepay_hash = self.lnworker.create_payment_info(amount_msat=prepay_amount_sat*1000)
|
||||
_, prepay_invoice = self.lnworker.get_bolt11_invoice(
|
||||
payment_hash=prepay_hash,
|
||||
|
||||
@@ -36,7 +36,7 @@ from electrum.lnmsg import encode_msg, decode_msg
|
||||
from electrum import lnmsg
|
||||
from electrum.logging import console_stderr_handler, Logger
|
||||
from electrum.lnworker import PaymentInfo, RECEIVED
|
||||
from electrum.lnonion import OnionFailureCode
|
||||
from electrum.lnonion import OnionFailureCode, OnionRoutingFailure
|
||||
from electrum.lnutil import UpdateAddHtlc
|
||||
from electrum.lnutil import LOCAL, REMOTE
|
||||
from electrum.invoices import PR_PAID, PR_UNPAID
|
||||
@@ -169,8 +169,8 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
self.sent_htlcs_q = defaultdict(asyncio.Queue)
|
||||
self.sent_htlcs_info = dict()
|
||||
self.sent_buckets = defaultdict(set)
|
||||
self.trampoline_forwardings = set()
|
||||
self.trampoline_forwarding_failures = {}
|
||||
self.final_onion_forwardings = set()
|
||||
self.final_onion_forwarding_failures = {}
|
||||
self.inflight_payments = set()
|
||||
self.preimages = {}
|
||||
self.stopping_soon = False
|
||||
@@ -749,6 +749,7 @@ class TestPeer(ElectrumTestCase):
|
||||
async def _test_simple_payment(
|
||||
self,
|
||||
test_trampoline: bool,
|
||||
test_failure:bool=False,
|
||||
test_hold_invoice=False,
|
||||
test_bundle=False,
|
||||
test_bundle_timeout=False
|
||||
@@ -765,12 +766,16 @@ class TestPeer(ElectrumTestCase):
|
||||
else:
|
||||
raise PaymentFailure()
|
||||
lnaddr, pay_req = self.prepare_invoice(w2)
|
||||
if test_hold_invoice:
|
||||
payment_hash = lnaddr.paymenthash
|
||||
payment_hash = lnaddr.paymenthash
|
||||
if test_failure or test_hold_invoice:
|
||||
preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex()))
|
||||
async def cb(payment_hash):
|
||||
w2.save_preimage(payment_hash, preimage)
|
||||
w2.register_callback_for_hold_invoice(payment_hash, cb, 60)
|
||||
if test_hold_invoice:
|
||||
async def cb(payment_hash):
|
||||
if not test_failure:
|
||||
w2.save_preimage(payment_hash, preimage)
|
||||
else:
|
||||
raise OnionRoutingFailure(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
|
||||
w2.register_callback_for_hold_invoice(payment_hash, cb)
|
||||
|
||||
if test_bundle:
|
||||
lnaddr2, pay_req2 = self.prepare_invoice(w2)
|
||||
@@ -799,11 +804,16 @@ class TestPeer(ElectrumTestCase):
|
||||
await f()
|
||||
|
||||
@needs_test_with_all_chacha20_implementations
|
||||
async def test_simple_payment(self):
|
||||
async def test_simple_payment_success(self):
|
||||
for test_trampoline in [False, True]:
|
||||
with self.assertRaises(PaymentDone):
|
||||
await self._test_simple_payment(test_trampoline=test_trampoline)
|
||||
|
||||
async def test_simple_payment_failure(self):
|
||||
for test_trampoline in [False, True]:
|
||||
with self.assertRaises(PaymentFailure):
|
||||
await self._test_simple_payment(test_trampoline=test_trampoline, test_failure=True)
|
||||
|
||||
async def test_payment_bundle(self):
|
||||
for test_trampoline in [False, True]:
|
||||
with self.assertRaises(PaymentDone):
|
||||
@@ -814,11 +824,16 @@ class TestPeer(ElectrumTestCase):
|
||||
with self.assertRaises(PaymentFailure):
|
||||
await self._test_simple_payment(test_trampoline=test_trampoline, test_bundle=True, test_bundle_timeout=True)
|
||||
|
||||
async def test_simple_payment_with_hold_invoice(self):
|
||||
async def test_simple_payment_success_with_hold_invoice(self):
|
||||
for test_trampoline in [False, True]:
|
||||
with self.assertRaises(PaymentDone):
|
||||
await self._test_simple_payment(test_trampoline=test_trampoline, test_hold_invoice=True)
|
||||
|
||||
async def test_simple_payment_failure_with_hold_invoice(self):
|
||||
for test_trampoline in [False, True]:
|
||||
with self.assertRaises(PaymentFailure):
|
||||
await self._test_simple_payment(test_trampoline=test_trampoline, test_hold_invoice=True, test_failure=True)
|
||||
|
||||
@needs_test_with_all_chacha20_implementations
|
||||
async def test_payment_race(self):
|
||||
"""Alice and Bob pay each other simultaneously.
|
||||
|
||||
Reference in New Issue
Block a user