diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 7cf778f94..26a166fcf 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -1336,6 +1336,84 @@ class TestPeerDirect(TestPeer): with self.assertRaises(SuccessfulTest): await f() + async def test_dont_settle_partial_mpp_trigger_with_invalid_cltv_htlc(self): + """Alice gets two htlcs as part of a mpp, one has a cltv too close to expiry and will get failed. + Test that the other htlc won't get settled if the mpp isn't complete anymore after failing the other htlc. + """ + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + async def pay(): + await util.wait_for2(p1.initialized, 1) + await util.wait_for2(p2.initialized, 1) + w2.features |= LnFeatures.BASIC_MPP_OPT + lnaddr1, _pay_req = self.prepare_invoice(w2, amount_msat=10_000, min_final_cltv_delta=144) + self.assertTrue(lnaddr1.get_features().supports(LnFeatures.BASIC_MPP_OPT)) + route = (await w1.create_routes_from_invoice(amount_msat=10_000, decoded_invoice=lnaddr1))[0][0].route + + # now p1 sends two htlcs, one is valid (1 msat), one is invalid (9_999 msat) + p1.pay( + route=route, + chan=alice_channel, + amount_msat=1, + total_msat=lnaddr1.get_amount_msat(), + payment_hash=lnaddr1.paymenthash, + # this htlc is valid and will get accepted, but it shouldn't get settled + min_final_cltv_delta=400, + payment_secret=lnaddr1.payment_secret, + ) + await asyncio.sleep(0.1) + assert w1.get_preimage(lnaddr1.paymenthash) is None + p1.pay( + route=route, + chan=alice_channel, + amount_msat=9_999, + total_msat=lnaddr1.get_amount_msat(), + payment_hash=lnaddr1.paymenthash, + # this htlc will get failed directly as the cltv is too close to expiry (< 144) + min_final_cltv_delta=1, + payment_secret=lnaddr1.payment_secret, + ) + + while nhtlc_success + nhtlc_failed < 2: + await htlc_resolved.wait() + # both htlcs of the mpp set should get failed and w2 shouldn't release the preimage + self.assertEqual(0, nhtlc_success, f"{nhtlc_success=} | {nhtlc_failed=}") + self.assertEqual(2, nhtlc_failed, f"{nhtlc_success=} | {nhtlc_failed=}") + assert w1.get_preimage(lnaddr1.paymenthash) is None, "w1 shouldn't get the preimage" + raise SuccessfulTest() + + async def f(): + async with OldTaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p1.htlc_switch()) + await group.spawn(p2._message_loop()) + await group.spawn(p2.htlc_switch()) + await asyncio.sleep(0.01) + await group.spawn(pay()) + + htlc_resolved = asyncio.Event() + nhtlc_success = 0 + nhtlc_failed = 0 + async def on_htlc_fulfilled(*args): + htlc_resolved.set() + htlc_resolved.clear() + nonlocal nhtlc_success + nhtlc_success += 1 + async def on_htlc_failed(*args): + htlc_resolved.set() + htlc_resolved.clear() + nonlocal nhtlc_failed + nhtlc_failed += 1 + util.register_callback(on_htlc_fulfilled, ["htlc_fulfilled"]) + util.register_callback(on_htlc_failed, ["htlc_failed"]) + + try: + with self.assertRaises(SuccessfulTest): + await f() + finally: + util.unregister_callback(on_htlc_fulfilled) + util.unregister_callback(on_htlc_failed) + async def test_legacy_shutdown_low(self): await self._test_shutdown(alice_fee=100, bob_fee=150)