diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 9c2731c60..085fc7576 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -2306,8 +2306,6 @@ class Peer(Logger, EventListener): local_height = self.network.blockchain().height() payment_hash = htlc_set.get_payment_hash() assert payment_hash is not None, "Empty htlc set?" - self.lnworker.dont_expire_htlcs.pop(payment_hash.hex(), None) - self.lnworker.dont_settle_htlcs.pop(payment_hash.hex(), None) # already failed for mpp_htlc in list(htlc_set.htlcs): chan = self.get_channel_by_id(mpp_htlc.channel_id) htlc_id = mpp_htlc.htlc.htlc_id @@ -3230,7 +3228,10 @@ class Peer(Logger, EventListener): # this was a forwarding set and it failed self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.FAILED) return error_bytes or failure_message, None, None - preimage = self.lnworker.get_preimage(mpp_set.get_payment_hash()) + payment_hash = mpp_set.get_payment_hash() + if payment_hash.hex() in self.lnworker.dont_settle_htlcs: + return None, None, None + preimage = self.lnworker.get_preimage(payment_hash) return None, preimage, None return None diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 3bc39c686..8b9f49def 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -33,7 +33,7 @@ from electrum.util import NetworkRetryManager, bfh, OldTaskGroup, EventListener, from electrum.lnpeer import Peer from electrum.lntransport import LNPeerAddr from electrum.crypto import privkey_to_pubkey -from electrum.lnutil import Keypair, PaymentFailure, LnFeatures, HTLCOwner, PaymentFeeBudget +from electrum.lnutil import Keypair, PaymentFailure, LnFeatures, HTLCOwner, PaymentFeeBudget, RECEIVED from electrum.lnchannel import ChannelState, PeerState, Channel from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent from electrum.channel_db import ChannelDB @@ -41,7 +41,7 @@ from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo, PaySession from electrum.lnmsg import encode_msg, decode_msg from electrum import lnmsg from electrum.logging import console_stderr_handler, Logger -from electrum.lnworker import PaymentInfo, RECEIVED +from electrum.lnworker import PaymentInfo from electrum.lnonion import OnionFailureCode, OnionRoutingFailure, OnionHopsDataSingle, OnionPacket from electrum.lnutil import LOCAL, REMOTE, UpdateAddHtlc, RecvMPPResolution from electrum.invoices import PR_PAID, PR_UNPAID, Invoice, LN_EXPIRY_NEVER @@ -2045,76 +2045,6 @@ class TestPeerDirect(TestPeer): with self.assertRaises(SuccessfulTest): await f() - async def test_dont_settle_htlcs(self): - """ - Test that htlcs registered in LNWallet.dont_settle_htlcs don't get fulfilled if the preimage is available. - """ - async def run_test(test_trampoline, test_failure): - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) - if test_trampoline: - await self._activate_trampoline(w1) - # declare bob as trampoline node - electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { - 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey), - } - - preimage = os.urandom(32) - lnaddr, pay_req = self.prepare_invoice( - w2, - payment_preimage=preimage, - # use a higher min final cltv delta so we can mine some blocks later - min_final_cltv_delta=244, - ) - - # add payment_hash to dont_settle_htlcs so the htlcs are not getting settled - w2.dont_settle_htlcs[pay_req.rhash] = None - - async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) - result, log = await util.wait_for2(w1.pay_invoice(pay_req), timeout=3) - if result is True: - self.assertNotIn(pay_req.rhash, w2.dont_settle_htlcs) - self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) - return PaymentDone() - else: - self.assertIsNone(w2.get_preimage(lnaddr.paymenthash)) - return PaymentFailure() - - async def wait_for_htlcs(): - payment_key = w2._get_payment_key(lnaddr.paymenthash) - while payment_key.hex() not in w2.received_mpp_htlcs: - await asyncio.sleep(0.05) - w2.network.blockchain()._height += 25 # mine some blocks, shouldn't affect anything - if test_failure: - # delete preimage, this will fail htlcs even if registered in dont_settle_htlcs - del w2._preimages[pay_req.rhash] - return # pay() should fail now - await asyncio.sleep(0.25) # give w2 some time to do mistakes - self.assertEqual(w2.received_mpp_htlcs[payment_key.hex()].resolution, RecvMPPResolution.COMPLETE) - # remove the payment hash from dont_settle_htlcs so the htlcs can get fulfilled - del w2.dont_settle_htlcs[pay_req.rhash] - - async def f(): - async with OldTaskGroup() as group: - await group.spawn(p1._message_loop()) - await group.spawn(p1.htlc_switch()) - await group.spawn(p2._message_loop()) - await group.spawn(p2.htlc_switch()) - await asyncio.sleep(0.01) - invoice_features = lnaddr.get_features() - self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT)) - pay_task = await group.spawn(pay(lnaddr, pay_req)) - await util.wait_for2(wait_for_htlcs(), timeout=2) - raise await pay_task - - await f() - - for test_trampoline in [False, True]: - for test_failure in [False, True]: - with self.assertRaises(PaymentFailure if test_failure else PaymentDone): - await run_test(test_trampoline, test_failure) - async def test_dont_expire_htlcs(self): """ Test that htlcs registered in LNWallet.dont_expire_htlcs don't get expired before the @@ -2978,6 +2908,79 @@ class TestPeerForwarding(TestPeer): any('bob->carol' in msg and 'on_update_fail_malformed_htlc' in msg for msg in logs.output) ) + async def test_dont_settle_htlcs_receiver_and_forwarder(self): + """ + Test that the receiver and forwarder doesn't settle htlcs once they get the preimage if the payment + hash is in LNWallet.dont_settle_htlcs. E.g. the forwarder could be a just-in-time channel provider. + Alice -> Bob -> Carol. Carol and Bob shouldn't release the preimage. + """ + async def run_test(test_trampoline): + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['line_graph']) + peers = graph.peers.values() + + if test_trampoline: + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { + graph.workers['bob'].name: LNPeerAddr(host="127.0.0.1", port=9735, pubkey=graph.workers['bob'].node_keypair.pubkey), + } + await self._activate_trampoline(graph.workers['carol']) + await self._activate_trampoline(graph.workers['alice']) + + lnaddr, pay_req = self.prepare_invoice(graph.workers['carol'], include_routing_hints=True) + # test both receiver (carol) and forwarder (bob) + graph.workers['bob'].dont_settle_htlcs[lnaddr.paymenthash.hex()] = None + graph.workers['carol'].dont_settle_htlcs[lnaddr.paymenthash.hex()] = None + + payment_successful = asyncio.Event() + async def pay(): + self.assertEqual(PR_UNPAID, graph.workers['carol'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) + result, log = await graph.workers['alice'].pay_invoice(pay_req) + self.assertEqual(PR_PAID, graph.workers['carol'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) + self.assertTrue(result) + payment_successful.set() + + async def check_doesnt_settle(): + while not graph.workers['carol'].received_mpp_htlcs: + await asyncio.sleep(0.1) # wait until carol received the htlcs + + await asyncio.sleep(0.2) # give carol time to accidentally release the preimage + self.assertEqual(PR_UNPAID, graph.workers['carol'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) + self.assertIsNone(graph.workers['bob'].get_preimage(lnaddr.paymenthash), "bob got preimage from carol") + # now allow carol to release the preimage to bob + del graph.workers['carol'].dont_settle_htlcs[lnaddr.paymenthash.hex()] + + # wait for carol to release the preimage to bob + while not graph.workers['bob'].get_preimage(lnaddr.paymenthash): + await asyncio.sleep(0.1) + + # give bob some time to settle the htlcs to alice (this would complete the payment) + await asyncio.sleep(0.2) + self.assertIsNone(graph.workers['alice'].get_preimage(lnaddr.paymenthash), "alice got preimage from bob") + self.assertFalse(payment_successful.is_set(), "bob released preimage") + + # now allow bob to settle the htlcs + del graph.workers['bob'].dont_settle_htlcs[lnaddr.paymenthash.hex()] + await payment_successful.wait() + raise PaymentDone() + + async def f(): + async with OldTaskGroup() as group: + for peer in peers: + await group.spawn(peer._message_loop()) + await group.spawn(peer.htlc_switch()) + for peer in peers: + await peer.initialized + + await group.spawn(pay()) + await group.spawn(check_doesnt_settle()) + # stop the taskgroup if anything takes too long + await group.spawn(asyncio.wait_for(asyncio.sleep(4), timeout=3)) + + await f() + + for trampoline in (False, True): + with self.assertRaises(PaymentDone): + await run_test(trampoline) + class TestPeerDirectAnchors(TestPeerDirect): TEST_ANCHOR_CHANNELS = True