diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index dee627c19..acc5a5ac9 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -2041,6 +2041,83 @@ class TestPeerDirect(TestPeer): with self.assertRaises(PaymentFailure if test_failure else PaymentDone): await run_test(test_trampoline, test_failure) + async def test_dont_expire_htlcs(self): + """ + Test that htlcs registered in LNWallet.dont_expire_htlcs don't get expired before the + specified expiry delta if their preimage isn't available. + Also test that htlcs registered in LNWallet.dont_expire_htlcs get settled right away if their + preimage is available. + """ + async def run_test(test_trampoline, test_expiry): + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + if test_trampoline: + await self._activate_trampoline(w1) + # declare bob as trampoline node + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { + 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey), + } + + preimage = os.urandom(32) + lnaddr, pay_req = self.prepare_invoice(w2, payment_preimage=preimage, min_final_cltv_delta=144) + + # delete preimage, this would fail the htlcs if payment_hash wasn't in dont_expire_htlcs + del w2._preimages[pay_req.rhash] + # add payment_hash to dont_expire_htlcs so the htlcs are not getting failed + w2.dont_expire_htlcs[pay_req.rhash] = None if not test_expiry else 20 + + async def pay(lnaddr, pay_req): + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) + result, log = await util.wait_for2(w1.pay_invoice(pay_req), timeout=3) + if result is True: + self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) + return PaymentDone() + else: + self.assertIsNone(w2.get_preimage(lnaddr.paymenthash)) + return PaymentFailure() + + async def wait_for_htlcs(): + payment_key = w2._get_payment_key(lnaddr.paymenthash) + while payment_key.hex() not in w2.received_mpp_htlcs: + await asyncio.sleep(0.05) + if not test_expiry: + # the htlcs should never get expired if the dont_expire_htlcs value is None + w2.network.blockchain()._height += 1000 + await asyncio.sleep(0.25) # give w2 some time to do mistakes + self.assertEqual(w2.received_mpp_htlcs[payment_key.hex()].resolution, RecvMPPResolution.COMPLETE) + if test_expiry: + # we set an expiry delta of 20 blocks before expiry, htlc expiry should be +144 current height + # so adding some blocks should get the htlcs failed + w2.network.blockchain()._height += 50 + await asyncio.sleep(0.1) + # the htlcs should not get failed yet as 144-50 > 20 + self.assertEqual(w2.received_mpp_htlcs[payment_key.hex()].resolution, RecvMPPResolution.COMPLETE) + w2.network.blockchain()._height += 75 + return # the htlcs should get failed and pay should return PaymentFailure + + # saving the preimage should let the htlcs get fulfilled + w2.save_preimage(lnaddr.paymenthash, preimage) + + async def f(): + async with OldTaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p1.htlc_switch()) + await group.spawn(p2._message_loop()) + await group.spawn(p2.htlc_switch()) + await asyncio.sleep(0.01) + invoice_features = lnaddr.get_features() + self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT)) + pay_task = await group.spawn(pay(lnaddr, pay_req)) + await util.wait_for2(wait_for_htlcs(), timeout=3) + raise await pay_task + + await f() + + for test_trampoline in [False, True]: + for test_expiry in [False, True]: + with self.assertRaises(PaymentFailure if test_expiry else PaymentDone): + await run_test(test_trampoline, test_expiry ) + class TestPeerForwarding(TestPeer):