test_lnpeer: refactor tests for hold invoices
This commit is contained in:
@@ -721,7 +721,7 @@ class TestPeer(ElectrumTestCase):
|
||||
with self.assertRaises(SuccessfulTest):
|
||||
await f()
|
||||
|
||||
async def _test_simple_payment(self, trampoline: bool):
|
||||
async def _test_simple_payment(self, trampoline: bool, test_hold_invoice=False, test_timeout=False):
|
||||
"""Alice pays Bob a single HTLC via direct channel."""
|
||||
alice_channel, bob_channel = create_test_channels()
|
||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||
@@ -738,7 +738,17 @@ class TestPeer(ElectrumTestCase):
|
||||
raise PaymentDone()
|
||||
else:
|
||||
raise PaymentFailure()
|
||||
async def f(test_hold_invoice=False, test_timeout=False):
|
||||
lnaddr, pay_req = self.prepare_invoice(w2)
|
||||
if test_hold_invoice:
|
||||
payment_hash = lnaddr.paymenthash
|
||||
preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex()))
|
||||
def cb(payment_hash):
|
||||
if not test_timeout:
|
||||
w2.save_preimage(payment_hash, preimage)
|
||||
timeout = 1 if test_timeout else 60
|
||||
w2.register_callback_for_hold_invoice(payment_hash, cb, timeout)
|
||||
|
||||
async def f():
|
||||
if trampoline:
|
||||
await turn_on_trampoline_alice()
|
||||
async with OldTaskGroup() as group:
|
||||
@@ -747,15 +757,6 @@ class TestPeer(ElectrumTestCase):
|
||||
await group.spawn(p2._message_loop())
|
||||
await group.spawn(p2.htlc_switch())
|
||||
await asyncio.sleep(0.01)
|
||||
lnaddr, pay_req = self.prepare_invoice(w2)
|
||||
if test_hold_invoice:
|
||||
payment_hash = lnaddr.paymenthash
|
||||
preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex()))
|
||||
def cb(payment_hash):
|
||||
if not test_timeout:
|
||||
w2.save_preimage(payment_hash, preimage)
|
||||
timeout = 1 if test_timeout else 60
|
||||
w2.register_callback_for_hold_invoice(payment_hash, cb, timeout)
|
||||
invoice_features = lnaddr.get_features()
|
||||
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
|
||||
await group.spawn(pay(lnaddr, pay_req))
|
||||
@@ -763,20 +764,23 @@ class TestPeer(ElectrumTestCase):
|
||||
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {
|
||||
'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey),
|
||||
}
|
||||
with self.assertRaises(PaymentDone):
|
||||
await f(test_hold_invoice=False)
|
||||
with self.assertRaises(PaymentDone):
|
||||
await f(test_hold_invoice=True, test_timeout=False)
|
||||
with self.assertRaises(PaymentFailure):
|
||||
await f(test_hold_invoice=True, test_timeout=True)
|
||||
await f()
|
||||
|
||||
@needs_test_with_all_chacha20_implementations
|
||||
async def test_simple_payment(self):
|
||||
await self._test_simple_payment(trampoline=False)
|
||||
for trampoline in [False, True]:
|
||||
with self.assertRaises(PaymentDone):
|
||||
await self._test_simple_payment(trampoline=trampoline)
|
||||
|
||||
@needs_test_with_all_chacha20_implementations
|
||||
async def test_simple_payment_trampoline(self):
|
||||
await self._test_simple_payment(trampoline=True)
|
||||
async def test_simple_payment_with_hold_invoice(self):
|
||||
for trampoline in [False, True]:
|
||||
with self.assertRaises(PaymentDone):
|
||||
await self._test_simple_payment(trampoline=trampoline, test_hold_invoice=True)
|
||||
|
||||
async def test_simple_payment_with_hold_invoice_timing_out(self):
|
||||
for trampoline in [False, True]:
|
||||
with self.assertRaises(PaymentFailure):
|
||||
await self._test_simple_payment(trampoline=trampoline, test_hold_invoice=True, test_timeout=True)
|
||||
|
||||
@needs_test_with_all_chacha20_implementations
|
||||
async def test_payment_race(self):
|
||||
|
||||
Reference in New Issue
Block a user