1
0

tests: lnpeer: add test_reject_mpp_for_non_mpp_invoice

This commit is contained in:
f321x
2025-12-09 17:49:25 +01:00
parent 183d426e93
commit 7c01d9db75

View File

@@ -1071,6 +1071,46 @@ class TestPeerDirect(TestPeer):
for _test_trampoline in [False, True]: for _test_trampoline in [False, True]:
await run_test(_test_trampoline) await run_test(_test_trampoline)
async def test_reject_mpp_for_non_mpp_invoice(self):
"""Test that we reject a payment if it is mpp and we didn't signal support for mpp in the invoice"""
async def run_test(test_trampoline):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
w1.config.TEST_FORCE_MPP = True # force alice to send mpp
if test_trampoline:
await self._activate_trampoline(w1)
await self._activate_trampoline(w2)
# declare bob as trampoline node
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {
'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey),
}
lnaddr, pay_req = self.prepare_invoice(w2)
self.assertFalse(lnaddr.get_features().supports(LnFeatures.BASIC_MPP_OPT))
self.assertFalse(lnaddr.get_features().supports(LnFeatures.BASIC_MPP_REQ))
async def try_pay_invoice_with_mpp(pay_req: Invoice, w1=w1):
result, log = await w1.pay_invoice(pay_req)
if not result:
raise PaymentFailure()
raise PaymentDone()
async def f():
async with OldTaskGroup() as group:
await group.spawn(p1._message_loop())
await group.spawn(p1.htlc_switch())
await group.spawn(p2._message_loop())
await group.spawn(p2.htlc_switch())
await asyncio.sleep(0.01)
await group.spawn(try_pay_invoice_with_mpp(pay_req))
with self.assertRaises(PaymentFailure):
await f()
for _test_trampoline in [False, True]:
await run_test(_test_trampoline)
async def test_reject_multiple_payments_of_same_invoice(self): async def test_reject_multiple_payments_of_same_invoice(self):
"""Tests that new htlcs paying an invoice that has already been paid will get rejected.""" """Tests that new htlcs paying an invoice that has already been paid will get rejected."""
async def run_test(test_trampoline): async def run_test(test_trampoline):