diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 2a86f5d7a..e5f98ec12 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -11,6 +11,7 @@ import concurrent from concurrent import futures from unittest import mock from typing import Iterable, NamedTuple, Tuple, List, Dict, Sequence +import time from aiorpcx import timeout_after, TaskTimeout from electrum_ecc import ECPrivkey @@ -559,6 +560,7 @@ class TestPeer(ElectrumTestCase): payment_hash: bytes = None, invoice_features: LnFeatures = None, min_final_cltv_delta: int = None, + expiry: int = None, ) -> Tuple[LnAddr, Invoice]: amount_btc = amount_msat/Decimal(COIN*1000) if payment_preimage is None and not payment_hash: @@ -586,7 +588,7 @@ class TestPeer(ElectrumTestCase): direction=RECEIVED, status=PR_UNPAID, min_final_cltv_delta=min_final_cltv_delta, - expiry_delay=LN_EXPIRY_NEVER, + expiry_delay=expiry or LN_EXPIRY_NEVER, ) w2.save_payment_info(info) lnaddr1 = LnAddr( @@ -596,6 +598,7 @@ class TestPeer(ElectrumTestCase): ('c', min_final_cltv_delta), ('d', 'coffee'), ('9', invoice_features), + ('x', expiry or 3600), ] + routing_hints, payment_secret=payment_secret, ) @@ -1000,6 +1003,49 @@ class TestPeerDirect(TestPeer): for _test_trampoline in [False, True]: await run_test(_test_trampoline) + async def test_reject_payment_for_expired_invoice(self): + """Tests that new htlcs paying an invoice that has already been expired will get rejected.""" + async def run_test(test_trampoline): + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + + # create lightning invoice in the past, so it is expired + with mock.patch('time.time', return_value=int(time.time()) - 10000): + lnaddr, _pay_req = self.prepare_invoice(w2, expiry=3600) + b11 = lnencode(lnaddr, w2.node_keypair.privkey) + pay_req = Invoice.from_bech32(b11) + + async def try_pay_expired_invoice(pay_req: Invoice, w1=w1): + assert pay_req.has_expired() + assert lnaddr.is_expired() + with mock.patch.object(w1, "_check_bolt11_invoice", return_value=lnaddr): + result, log = await w1.pay_invoice(pay_req) + if not result: + raise PaymentFailure() + raise PaymentDone() + + if test_trampoline: + await self._activate_trampoline(w1) + # declare bob as trampoline node + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { + 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey), + } + + async def f(): + async with OldTaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p1.htlc_switch()) + await group.spawn(p2._message_loop()) + await group.spawn(p2.htlc_switch()) + await asyncio.sleep(0.01) + await group.spawn(try_pay_expired_invoice(pay_req)) + + with self.assertRaises(PaymentFailure): + await f() + + for _test_trampoline in [False, True]: + await run_test(_test_trampoline) + async def test_payment_race(self): """Alice and Bob pay each other simultaneously. They both send 'update_add_htlc' and receive each other's update