lnworker: bundled payments
- htlcs of bundled payments must arrive in the same MPP_TIMEOUT window, or they will be failed - add correspoding tests
This commit is contained in:
@@ -677,6 +677,8 @@ class LNWallet(LNWorker):
|
||||
# map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys
|
||||
self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes]
|
||||
self.hold_invoice_callbacks = {} # payment_hash -> callback, timeout
|
||||
self.payment_bundles = [] # lists of hashes. todo:persist
|
||||
|
||||
|
||||
def has_deterministic_node_id(self) -> bool:
|
||||
return bool(self.db.get('lightning_xprv'))
|
||||
@@ -1862,6 +1864,14 @@ class LNWallet(LNWorker):
|
||||
self.wallet.save_db()
|
||||
return payment_hash
|
||||
|
||||
def bundle_payments(self, hash_list):
|
||||
self.payment_bundles.append(hash_list)
|
||||
|
||||
def get_payment_bundle(self, payment_hash):
|
||||
for hash_list in self.payment_bundles:
|
||||
if payment_hash in hash_list:
|
||||
return hash_list
|
||||
|
||||
def save_preimage(self, payment_hash: bytes, preimage: bytes, *, write_to_disk: bool = True):
|
||||
assert sha256(preimage) == payment_hash
|
||||
self.preimages[payment_hash.hex()] = preimage.hex()
|
||||
@@ -1901,45 +1911,87 @@ class LNWallet(LNWorker):
|
||||
""" return MPP status: True (accepted), False (expired) or None (waiting)
|
||||
"""
|
||||
payment_hash = htlc.payment_hash
|
||||
preimage = self.get_preimage(payment_hash)
|
||||
callback = self.hold_invoice_callbacks.get(payment_hash)
|
||||
if not preimage and callback:
|
||||
cb, timeout = callback
|
||||
if int(time.time()) < timeout:
|
||||
cb(payment_hash)
|
||||
return None
|
||||
else:
|
||||
return False
|
||||
|
||||
amt_to_forward = htlc.amount_msat # check this
|
||||
if amt_to_forward >= expected_msat:
|
||||
# not multi-part
|
||||
return True
|
||||
self.update_mpp_with_received_htlc(payment_secret, short_channel_id, htlc, expected_msat)
|
||||
is_expired, is_accepted = self.get_mpp_status(payment_secret)
|
||||
if not is_accepted and not is_expired:
|
||||
bundle = self.get_payment_bundle(payment_hash)
|
||||
payment_hashes = bundle or [payment_hash]
|
||||
payment_secrets = [self.get_payment_secret(h) for h in bundle] if bundle else [payment_secret]
|
||||
first_timestamp = min([self.get_first_timestamp_of_mpp(x) for x in payment_secrets])
|
||||
if self.get_payment_status(payment_hash) == PR_PAID:
|
||||
is_accepted = True
|
||||
elif self.stopping_soon:
|
||||
is_expired = True # try to time out pending HTLCs before shutting down
|
||||
elif time.time() - first_timestamp > self.MPP_EXPIRY:
|
||||
is_expired = True
|
||||
elif all([self.is_mpp_amount_reached(x) for x in payment_secrets]):
|
||||
preimage = self.get_preimage(payment_hash)
|
||||
hold_invoice_callback = self.hold_invoice_callbacks.get(payment_hash)
|
||||
if not preimage and hold_invoice_callback:
|
||||
# for hold invoices, trigger callback
|
||||
cb, timeout = hold_invoice_callback
|
||||
if int(time.time()) < timeout:
|
||||
cb(payment_hash)
|
||||
else:
|
||||
is_expired = True
|
||||
elif bundle is not None:
|
||||
is_accepted = all([bool(self.get_preimage(x)) for x in bundle])
|
||||
else:
|
||||
# trampoline forwarding needs this to return True
|
||||
is_accepted = True
|
||||
|
||||
is_expired, is_accepted, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, set()))
|
||||
if self.get_payment_status(payment_hash) == PR_PAID:
|
||||
# payment_status is persisted
|
||||
is_accepted = True
|
||||
is_expired = False
|
||||
# set status for the bundle
|
||||
if is_expired or is_accepted:
|
||||
for x in payment_secrets:
|
||||
if x in self.received_mpp_htlcs:
|
||||
self.set_mpp_status(x, is_expired, is_accepted)
|
||||
|
||||
self.maybe_cleanup_mpp_status(payment_secret, short_channel_id, htlc)
|
||||
return True if is_accepted else (False if is_expired else None)
|
||||
|
||||
def update_mpp_with_received_htlc(self, payment_secret, short_channel_id, htlc, expected_msat):
|
||||
# add new htlc to set
|
||||
is_expired, is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, expected_msat, set()))
|
||||
assert expected_msat == _expected_msat
|
||||
key = (short_channel_id, htlc)
|
||||
if key not in htlc_set:
|
||||
htlc_set.add(key)
|
||||
self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, _expected_msat, htlc_set
|
||||
|
||||
def get_mpp_status(self, payment_secret):
|
||||
is_expired, is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs[payment_secret]
|
||||
return is_expired, is_accepted
|
||||
|
||||
def set_mpp_status(self, payment_secret, is_expired, is_accepted):
|
||||
_is_expired, _is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs[payment_secret]
|
||||
self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, _expected_msat, htlc_set
|
||||
|
||||
def is_mpp_amount_reached(self, payment_secret):
|
||||
mpp = self.received_mpp_htlcs.get(payment_secret)
|
||||
if not mpp:
|
||||
return False
|
||||
is_expired, is_accepted, _expected_msat, htlc_set = mpp
|
||||
total = sum([_htlc.amount_msat for scid, _htlc in htlc_set])
|
||||
return total >= _expected_msat
|
||||
|
||||
def get_first_timestamp_of_mpp(self, payment_secret):
|
||||
mpp = self.received_mpp_htlcs.get(payment_secret)
|
||||
if not mpp:
|
||||
return int(time.time())
|
||||
is_expired, is_accepted, _expected_msat, htlc_set = mpp
|
||||
return min([_htlc.timestamp for scid, _htlc in htlc_set])
|
||||
|
||||
def maybe_cleanup_mpp_status(self, payment_secret, short_channel_id, htlc):
|
||||
is_expired, is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs[payment_secret]
|
||||
if not is_accepted and not is_expired:
|
||||
total = sum([_htlc.amount_msat for scid, _htlc in htlc_set])
|
||||
first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set])
|
||||
if self.stopping_soon:
|
||||
is_expired = True # try to time out pending HTLCs before shutting down
|
||||
elif time.time() - first_timestamp > self.MPP_EXPIRY:
|
||||
is_expired = True
|
||||
elif total == expected_msat:
|
||||
is_accepted = True
|
||||
if is_accepted or is_expired:
|
||||
htlc_set.remove(key)
|
||||
return
|
||||
key = (short_channel_id, htlc)
|
||||
htlc_set.remove(key)
|
||||
if len(htlc_set) > 0:
|
||||
self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, htlc_set
|
||||
self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, _expected_msat, htlc_set
|
||||
elif payment_secret in self.received_mpp_htlcs:
|
||||
self.received_mpp_htlcs.pop(payment_secret)
|
||||
return True if is_accepted else (False if is_expired else None)
|
||||
|
||||
def get_payment_status(self, payment_hash: bytes) -> int:
|
||||
info = self.get_payment_info(payment_hash)
|
||||
|
||||
@@ -174,6 +174,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
self.stopping_soon = False
|
||||
self.downstream_htlc_to_upstream_peer_map = {}
|
||||
self.hold_invoice_callbacks = {}
|
||||
self.payment_bundles = [] # lists of hashes. todo:persist
|
||||
|
||||
self.logger.info(f"created LNWallet[{name}] with nodeID={local_keypair.pubkey.hex()}")
|
||||
|
||||
@@ -279,6 +280,16 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
register_callback_for_hold_invoice = LNWallet.register_callback_for_hold_invoice
|
||||
add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice
|
||||
|
||||
update_mpp_with_received_htlc = LNWallet.update_mpp_with_received_htlc
|
||||
get_mpp_status = LNWallet.get_mpp_status
|
||||
set_mpp_status = LNWallet.set_mpp_status
|
||||
is_mpp_amount_reached = LNWallet.is_mpp_amount_reached
|
||||
get_first_timestamp_of_mpp = LNWallet.get_first_timestamp_of_mpp
|
||||
maybe_cleanup_mpp_status = LNWallet.maybe_cleanup_mpp_status
|
||||
bundle_payments = LNWallet.bundle_payments
|
||||
get_payment_bundle = LNWallet.get_payment_bundle
|
||||
|
||||
|
||||
class MockTransport:
|
||||
def __init__(self, name):
|
||||
self.queue = asyncio.Queue() # incoming messages
|
||||
@@ -727,7 +738,14 @@ class TestPeer(ElectrumTestCase):
|
||||
await w.network.channel_db.stopped_event.wait()
|
||||
w.network.channel_db = None
|
||||
|
||||
async def _test_simple_payment(self, trampoline: bool, test_hold_invoice=False, test_timeout=False):
|
||||
async def _test_simple_payment(
|
||||
self,
|
||||
test_trampoline: bool,
|
||||
test_hold_invoice=False,
|
||||
test_hold_timeout=False,
|
||||
test_bundle=False,
|
||||
test_bundle_timeout=False
|
||||
):
|
||||
"""Alice pays Bob a single HTLC via direct channel."""
|
||||
alice_channel, bob_channel = create_test_channels()
|
||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||
@@ -746,12 +764,21 @@ class TestPeer(ElectrumTestCase):
|
||||
def cb(payment_hash):
|
||||
if not test_timeout:
|
||||
w2.save_preimage(payment_hash, preimage)
|
||||
timeout = 1 if test_timeout else 60
|
||||
timeout = 1 if test_hold_timeout else 60
|
||||
w2.register_callback_for_hold_invoice(payment_hash, cb, timeout)
|
||||
|
||||
if test_bundle:
|
||||
lnaddr2, pay_req2 = self.prepare_invoice(w2)
|
||||
w2.bundle_payments([lnaddr.paymenthash, lnaddr2.paymenthash])
|
||||
|
||||
if test_trampoline:
|
||||
await self._activate_trampoline(w1)
|
||||
# declare bob as trampoline node
|
||||
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {
|
||||
'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey),
|
||||
}
|
||||
|
||||
async def f():
|
||||
if trampoline:
|
||||
await self._activate_trampoline(w1)
|
||||
async with OldTaskGroup() as group:
|
||||
await group.spawn(p1._message_loop())
|
||||
await group.spawn(p1.htlc_switch())
|
||||
@@ -761,22 +788,31 @@ class TestPeer(ElectrumTestCase):
|
||||
invoice_features = lnaddr.get_features()
|
||||
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
|
||||
await group.spawn(pay(lnaddr, pay_req))
|
||||
# declare bob as trampoline node
|
||||
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {
|
||||
'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey),
|
||||
}
|
||||
if test_bundle and not test_bundle_timeout:
|
||||
await group.spawn(pay(lnaddr2, pay_req2))
|
||||
|
||||
await f()
|
||||
|
||||
@needs_test_with_all_chacha20_implementations
|
||||
async def test_simple_payment(self):
|
||||
for trampoline in [False, True]:
|
||||
for test_trampoline in [False, True]:
|
||||
with self.assertRaises(PaymentDone):
|
||||
await self._test_simple_payment(trampoline=trampoline)
|
||||
await self._test_simple_payment(test_trampoline=test_trampoline)
|
||||
|
||||
async def test_payment_bundle(self):
|
||||
for test_trampoline in [False, True]:
|
||||
with self.assertRaises(PaymentDone):
|
||||
await self._test_simple_payment(test_trampoline=test_trampoline, test_bundle=True)
|
||||
|
||||
async def test_payment_bundle_timeout(self):
|
||||
for test_trampoline in [False, True]:
|
||||
with self.assertRaises(PaymentFailure):
|
||||
await self._test_simple_payment(test_trampoline=test_trampoline, test_bundle=True, test_bundle_timeout=True)
|
||||
|
||||
async def test_simple_payment_with_hold_invoice(self):
|
||||
for trampoline in [False, True]:
|
||||
for test_trampoline in [False, True]:
|
||||
with self.assertRaises(PaymentDone):
|
||||
await self._test_simple_payment(trampoline=trampoline, test_hold_invoice=True)
|
||||
await self._test_simple_payment(test_trampoline=test_trampoline, test_hold_invoice=True)
|
||||
|
||||
async def test_simple_payment_with_hold_invoice_timing_out(self):
|
||||
for trampoline in [False, True]:
|
||||
|
||||
Reference in New Issue
Block a user