diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 26a166fcf..97813f5f7 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -40,8 +40,7 @@ from electrum import lnmsg from electrum.logging import console_stderr_handler, Logger from electrum.lnworker import PaymentInfo, RECEIVED from electrum.lnonion import OnionFailureCode, OnionRoutingFailure -from electrum.lnutil import UpdateAddHtlc -from electrum.lnutil import LOCAL, REMOTE +from electrum.lnutil import LOCAL, REMOTE, UpdateAddHtlc, RecvMPPResolution from electrum.invoices import PR_PAID, PR_UNPAID, Invoice, LN_EXPIRY_NEVER from electrum.interface import GracefulDisconnect from electrum.simple_config import SimpleConfig @@ -1414,6 +1413,105 @@ class TestPeerDirect(TestPeer): util.unregister_callback(on_htlc_fulfilled) util.unregister_callback(on_htlc_failed) + async def test_mpp_cleanup_after_expiry(self): + """ + 1. Alice sends two HTLCs to Bob, not reaching total_msat, and eventually they MPP_TIMEOUT + 2. Bob fails both HTLCs + 3. Alice then retries and sends HTLCs again to Bob, for the same RHASH, + this time reaching total_msat, and the payment succeeds + + Test that the sets are properly cleaned up after MPP_TIMEOUT + and the sender gets a second chance to pay the same invoice. + """ + async def run_test(test_trampoline: bool): + alice_channel, bob_channel = create_test_channels() + alice_peer, bob_peer, alice_wallet, bob_wallet, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + lnaddr1, pay_req1 = self.prepare_invoice(bob_wallet, amount_msat=10_000) + + if test_trampoline: + await self._activate_trampoline(alice_wallet) + # declare bob as trampoline node + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { + 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=bob_wallet.node_keypair.pubkey), + } + + async def _test(): + route = (await alice_wallet.create_routes_from_invoice(amount_msat=10_000, decoded_invoice=lnaddr1))[0][0].route + assert len(bob_wallet.received_mpp_htlcs) == 0 + # now alice sends two small htlcs, so the set stays incomplete + alice_peer.pay( # htlc 1 + route=route, + chan=alice_channel, + amount_msat=lnaddr1.get_amount_msat() // 4, + total_msat=lnaddr1.get_amount_msat(), + payment_hash=lnaddr1.paymenthash, + min_final_cltv_delta=400, + payment_secret=lnaddr1.payment_secret, + ) + alice_peer.pay( # htlc 2 + route=route, + chan=alice_channel, + amount_msat=lnaddr1.get_amount_msat() // 4, + total_msat=lnaddr1.get_amount_msat(), + payment_hash=lnaddr1.paymenthash, + min_final_cltv_delta=400, + payment_secret=lnaddr1.payment_secret, + ) + await asyncio.sleep(bob_wallet.MPP_EXPIRY // 2) # give bob time to receive the htlc + bob_payment_key = bob_wallet._get_payment_key(lnaddr1.paymenthash).hex() + assert bob_wallet.received_mpp_htlcs[bob_payment_key].resolution == RecvMPPResolution.WAITING + assert len(bob_wallet.received_mpp_htlcs[bob_payment_key].htlcs) == 2 + # now wait until bob expires the mpp (set) + await asyncio.wait_for(alice_htlc_resolved.wait(), bob_wallet.MPP_EXPIRY * 3) # this can take some time, esp. on CI + # check that bob failed the htlc + assert nhtlc_success == 0 and nhtlc_failed == 2 + # check that bob deleted the mpp set as it should be expired and resolved now + assert bob_payment_key not in bob_wallet.received_mpp_htlcs + alice_wallet._paysessions.clear() + assert alice_wallet.get_preimage(lnaddr1.paymenthash) is None # bob didn't preimage + # now try to pay again, this time the full amount + result, log = await alice_wallet.pay_invoice(pay_req1) + assert result is True + assert alice_wallet.get_preimage(lnaddr1.paymenthash) is not None # bob revealed preimage + assert len(bob_wallet.received_mpp_htlcs) == 0 # bob should also clean up a successful set + raise SuccessfulTest() + + async def f(): + async with OldTaskGroup() as group: + await group.spawn(alice_peer._message_loop()) + await group.spawn(alice_peer.htlc_switch()) + await group.spawn(bob_peer._message_loop()) + await group.spawn(bob_peer.htlc_switch()) + await asyncio.sleep(0.01) + await group.spawn(_test()) + + alice_htlc_resolved = asyncio.Event() + nhtlc_success = 0 + nhtlc_failed = 0 + async def on_sender_htlc_fulfilled(*args): + alice_htlc_resolved.set() + alice_htlc_resolved.clear() + nonlocal nhtlc_success + nhtlc_success += 1 + async def on_sender_htlc_failed(*args): + alice_htlc_resolved.set() + alice_htlc_resolved.clear() + nonlocal nhtlc_failed + nhtlc_failed += 1 + util.register_callback(on_sender_htlc_fulfilled, ["htlc_fulfilled"]) + util.register_callback(on_sender_htlc_failed, ["htlc_failed"]) + + try: + with self.assertRaises(SuccessfulTest): + await f() + finally: + util.unregister_callback(on_sender_htlc_fulfilled) + util.unregister_callback(on_sender_htlc_failed) + + for use_trampoline in [True, False]: + self.logger.debug(f"test_mpp_cleanup_after_expiry: {use_trampoline=}") + await run_test(use_trampoline) + async def test_legacy_shutdown_low(self): await self._test_shutdown(alice_fee=100, bob_fee=150)