diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 204ebd6f4..09cdbc24d 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -862,20 +862,24 @@ class TestPeerDirect(TestPeer): """Alice pays Bob a single HTLC via direct channel.""" alice_channel, bob_channel = create_test_channels() p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + results = {} async def pay(lnaddr, pay_req): self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) result, log = await w1.pay_invoice(pay_req) if result is True: self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) - raise PaymentDone() + results[lnaddr] = PaymentDone() else: - raise PaymentFailure() + results[lnaddr] = PaymentFailure() lnaddr, pay_req = self.prepare_invoice(w2) + to_pay = [(lnaddr, pay_req)] self.prepare_recipient(w2, lnaddr.paymenthash, test_hold_invoice, test_failure) if test_bundle: lnaddr2, pay_req2 = self.prepare_invoice(w2) w2.bundle_payments([lnaddr.paymenthash, lnaddr2.paymenthash]) + if not test_bundle_timeout: + to_pay.append((lnaddr2, pay_req2)) if test_trampoline: await self._activate_trampoline(w1) @@ -893,9 +897,16 @@ class TestPeerDirect(TestPeer): await asyncio.sleep(0.01) invoice_features = lnaddr.get_features() self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT)) - await group.spawn(pay(lnaddr, pay_req)) - if test_bundle and not test_bundle_timeout: - await group.spawn(pay(lnaddr2, pay_req2)) + for lnaddr_to_pay, pay_req_to_pay in to_pay: + await group.spawn(pay(lnaddr_to_pay, pay_req_to_pay)) + elapsed = 0 + while len(results) < len(to_pay) and elapsed < 4: + await asyncio.sleep(0.05) # wait for all payments to finish/fail (or timeout) + elapsed += 0.05 + self.assertEqual(len(results), len(to_pay), msg="timeout") + # all payment results should be similar + self.assertEqual(len(set(type(res) for res in results.values())), 1, msg=results) + raise list(results.values())[0] await f() @@ -919,6 +930,11 @@ class TestPeerDirect(TestPeer): with self.assertRaises(PaymentFailure): await self._test_simple_payment(test_trampoline=test_trampoline, test_bundle=True, test_bundle_timeout=True) + async def test_payment_bundle_with_hold_invoice(self): + for test_trampoline in [False, True]: + with self.assertRaises(PaymentDone): + await self._test_simple_payment(test_trampoline=test_trampoline, test_bundle=True, test_hold_invoice=True) + async def test_simple_payment_success_with_hold_invoice(self): for test_trampoline in [False, True]: with self.assertRaises(PaymentDone):