diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 9d9ddc699..db38087ba 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -104,11 +104,13 @@ class MockNetwork: class MockBlockchain: - - def height(self): + def __init__(self): # Let's return a non-zero, realistic height. # 0 might hide relative vs abs locktime confusion bugs. - return 600_000 + self._height = 600_000 + + def height(self): + return self._height def is_tip_stale(self): return False @@ -1807,6 +1809,83 @@ class TestPeerDirect(TestPeer): await f() self.assertTrue(isinstance(failing_task.exception().__cause__, lnmsg.UnexpectedEndOfStream)) + async def test_hold_invoice_set_doesnt_get_expired(self): + """ + Alice pays a hold invoice from Bob, Bob doesn't release preimage. Verify that Bob doesn't + expire the htlc set MIN_FINAL_CLTV_DELTA_ACCEPTED blocks before htlc.cltv_abs (as we would do with normal htlc sets). + The htlc set should only get failed if the user of the hold invoice callback explicitly removes the + callback (e.g. after refunding and failing a swap), otherwise it should get timed out onchain (force-close). + + This only tests hold invoice logic for hold invoices registered with `LNWallet.register_hold_invoice()`, + as used e.g. by submarine swaps. It doesn't cover the hold invoices created by the hold invoice CLI + which behave differently and use the persisted `LNWallet.dont_expire_htlcs` dict. + """ + async def run_test(test_trampoline): + alice_channel, bob_channel = create_test_channels() + alice_p, bob_p, alice_w, bob_w, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + + lnaddr, pay_req = self.prepare_invoice(bob_w, min_final_cltv_delta=150) + del bob_w._preimages[pay_req.rhash] # del preimage so bob doesn't settle + payment_key = bob_w._get_payment_key(lnaddr.paymenthash).hex() + + cb_got_called = False + async def cb(_payment_hash): + self.logger.debug(f"hold invoice callback called. {bob_w.network.get_local_height()=}") + nonlocal cb_got_called + cb_got_called = True + + bob_w.register_hold_invoice(lnaddr.paymenthash, cb) + + async def check_mpp_state(): + async def wait_for_resolution(): + while True: + await asyncio.sleep(0.1) + if payment_key not in bob_w.received_mpp_htlcs: + continue + if not bob_w.received_mpp_htlcs[payment_key].resolution == RecvMPPResolution.SETTLING: + continue + return + await util.wait_for2(wait_for_resolution(), timeout=2) + assert cb_got_called + mpp_set = bob_w.received_mpp_htlcs[payment_key] + self.assertEqual(mpp_set.resolution, RecvMPPResolution.SETTLING, msg=mpp_set.resolution) + self.assertEqual(len(mpp_set.htlcs), 1, f"should get only one htlc: {mpp_set.htlcs=}") + left_to_expiry = next(iter(mpp_set.htlcs)).htlc.cltv_abs - bob_w.network.get_local_height() + # now mine up to one block after the expiry + bob_w.network._blockchain._height += left_to_expiry + 1 + await asyncio.sleep(0.2) + # bob still has the mpp set and it is not failed + # it should only get removed once the channel is redeemed + self.assertIn(bob_w.received_mpp_htlcs[payment_key].resolution, (RecvMPPResolution.COMPLETE, RecvMPPResolution.SETTLING)) + # now also check that the mpp set will get set failed if the hold invoice + # is being explicitly unregistered, and we don't have a preimage to settle it + bob_w.unregister_hold_invoice(lnaddr.paymenthash) + self.assertEqual(bob_w.received_mpp_htlcs[payment_key].resolution, RecvMPPResolution.FAILED) + raise SuccessfulTest() + + if test_trampoline: + await self._activate_trampoline(alice_w) + # declare bob as trampoline node + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { + 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=bob_w.node_keypair.pubkey), + } + + async def f(): + async with OldTaskGroup() as group: + await group.spawn(alice_p._message_loop()) + await group.spawn(alice_p.htlc_switch()) + await group.spawn(bob_p._message_loop()) + await group.spawn(bob_p.htlc_switch()) + await asyncio.sleep(0.01) + await group.spawn(alice_w.pay_invoice(pay_req)) + await group.spawn(check_mpp_state()) + + with self.assertRaises(SuccessfulTest): + await f() + + for _test_trampoline in [False, True]: + await run_test(_test_trampoline) + class TestPeerForwarding(TestPeer):