1
0

tests: test_lnpeer: test_hold_invoice_set_doesnt_get_exp

Add test `test_hold_invoice_set_doesnt_get_expired` to test_lnpeer to
ensure a mpp set on which a hold invoice callback doesn't get expired
automatically if the cltv_abs falls below MIN_FINAL_CLTV_DELTA_ACCEPTED
as these sets should only get failed if the htlcs are safe to fail by
the target of the hold invoice callback (e.g. swap got refunded
successfully).
This commit is contained in:
f321x
2025-10-08 12:59:48 +02:00
parent bb828097b3
commit f56b13b610

View File

@@ -104,11 +104,13 @@ class MockNetwork:
class MockBlockchain:
def height(self):
def __init__(self):
# Let's return a non-zero, realistic height.
# 0 might hide relative vs abs locktime confusion bugs.
return 600_000
self._height = 600_000
def height(self):
return self._height
def is_tip_stale(self):
return False
@@ -1807,6 +1809,83 @@ class TestPeerDirect(TestPeer):
await f()
self.assertTrue(isinstance(failing_task.exception().__cause__, lnmsg.UnexpectedEndOfStream))
async def test_hold_invoice_set_doesnt_get_expired(self):
"""
Alice pays a hold invoice from Bob, Bob doesn't release preimage. Verify that Bob doesn't
expire the htlc set MIN_FINAL_CLTV_DELTA_ACCEPTED blocks before htlc.cltv_abs (as we would do with normal htlc sets).
The htlc set should only get failed if the user of the hold invoice callback explicitly removes the
callback (e.g. after refunding and failing a swap), otherwise it should get timed out onchain (force-close).
This only tests hold invoice logic for hold invoices registered with `LNWallet.register_hold_invoice()`,
as used e.g. by submarine swaps. It doesn't cover the hold invoices created by the hold invoice CLI
which behave differently and use the persisted `LNWallet.dont_expire_htlcs` dict.
"""
async def run_test(test_trampoline):
alice_channel, bob_channel = create_test_channels()
alice_p, bob_p, alice_w, bob_w, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
lnaddr, pay_req = self.prepare_invoice(bob_w, min_final_cltv_delta=150)
del bob_w._preimages[pay_req.rhash] # del preimage so bob doesn't settle
payment_key = bob_w._get_payment_key(lnaddr.paymenthash).hex()
cb_got_called = False
async def cb(_payment_hash):
self.logger.debug(f"hold invoice callback called. {bob_w.network.get_local_height()=}")
nonlocal cb_got_called
cb_got_called = True
bob_w.register_hold_invoice(lnaddr.paymenthash, cb)
async def check_mpp_state():
async def wait_for_resolution():
while True:
await asyncio.sleep(0.1)
if payment_key not in bob_w.received_mpp_htlcs:
continue
if not bob_w.received_mpp_htlcs[payment_key].resolution == RecvMPPResolution.SETTLING:
continue
return
await util.wait_for2(wait_for_resolution(), timeout=2)
assert cb_got_called
mpp_set = bob_w.received_mpp_htlcs[payment_key]
self.assertEqual(mpp_set.resolution, RecvMPPResolution.SETTLING, msg=mpp_set.resolution)
self.assertEqual(len(mpp_set.htlcs), 1, f"should get only one htlc: {mpp_set.htlcs=}")
left_to_expiry = next(iter(mpp_set.htlcs)).htlc.cltv_abs - bob_w.network.get_local_height()
# now mine up to one block after the expiry
bob_w.network._blockchain._height += left_to_expiry + 1
await asyncio.sleep(0.2)
# bob still has the mpp set and it is not failed
# it should only get removed once the channel is redeemed
self.assertIn(bob_w.received_mpp_htlcs[payment_key].resolution, (RecvMPPResolution.COMPLETE, RecvMPPResolution.SETTLING))
# now also check that the mpp set will get set failed if the hold invoice
# is being explicitly unregistered, and we don't have a preimage to settle it
bob_w.unregister_hold_invoice(lnaddr.paymenthash)
self.assertEqual(bob_w.received_mpp_htlcs[payment_key].resolution, RecvMPPResolution.FAILED)
raise SuccessfulTest()
if test_trampoline:
await self._activate_trampoline(alice_w)
# declare bob as trampoline node
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {
'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=bob_w.node_keypair.pubkey),
}
async def f():
async with OldTaskGroup() as group:
await group.spawn(alice_p._message_loop())
await group.spawn(alice_p.htlc_switch())
await group.spawn(bob_p._message_loop())
await group.spawn(bob_p.htlc_switch())
await asyncio.sleep(0.01)
await group.spawn(alice_w.pay_invoice(pay_req))
await group.spawn(check_mpp_state())
with self.assertRaises(SuccessfulTest):
await f()
for _test_trampoline in [False, True]:
await run_test(_test_trampoline)
class TestPeerForwarding(TestPeer):