test_lnpeer: refactor tests for hold invoices
This commit is contained in:
@@ -721,7 +721,7 @@ class TestPeer(ElectrumTestCase):
|
|||||||
with self.assertRaises(SuccessfulTest):
|
with self.assertRaises(SuccessfulTest):
|
||||||
await f()
|
await f()
|
||||||
|
|
||||||
async def _test_simple_payment(self, trampoline: bool):
|
async def _test_simple_payment(self, trampoline: bool, test_hold_invoice=False, test_timeout=False):
|
||||||
"""Alice pays Bob a single HTLC via direct channel."""
|
"""Alice pays Bob a single HTLC via direct channel."""
|
||||||
alice_channel, bob_channel = create_test_channels()
|
alice_channel, bob_channel = create_test_channels()
|
||||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||||
@@ -738,7 +738,17 @@ class TestPeer(ElectrumTestCase):
|
|||||||
raise PaymentDone()
|
raise PaymentDone()
|
||||||
else:
|
else:
|
||||||
raise PaymentFailure()
|
raise PaymentFailure()
|
||||||
async def f(test_hold_invoice=False, test_timeout=False):
|
lnaddr, pay_req = self.prepare_invoice(w2)
|
||||||
|
if test_hold_invoice:
|
||||||
|
payment_hash = lnaddr.paymenthash
|
||||||
|
preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex()))
|
||||||
|
def cb(payment_hash):
|
||||||
|
if not test_timeout:
|
||||||
|
w2.save_preimage(payment_hash, preimage)
|
||||||
|
timeout = 1 if test_timeout else 60
|
||||||
|
w2.register_callback_for_hold_invoice(payment_hash, cb, timeout)
|
||||||
|
|
||||||
|
async def f():
|
||||||
if trampoline:
|
if trampoline:
|
||||||
await turn_on_trampoline_alice()
|
await turn_on_trampoline_alice()
|
||||||
async with OldTaskGroup() as group:
|
async with OldTaskGroup() as group:
|
||||||
@@ -747,15 +757,6 @@ class TestPeer(ElectrumTestCase):
|
|||||||
await group.spawn(p2._message_loop())
|
await group.spawn(p2._message_loop())
|
||||||
await group.spawn(p2.htlc_switch())
|
await group.spawn(p2.htlc_switch())
|
||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
lnaddr, pay_req = self.prepare_invoice(w2)
|
|
||||||
if test_hold_invoice:
|
|
||||||
payment_hash = lnaddr.paymenthash
|
|
||||||
preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex()))
|
|
||||||
def cb(payment_hash):
|
|
||||||
if not test_timeout:
|
|
||||||
w2.save_preimage(payment_hash, preimage)
|
|
||||||
timeout = 1 if test_timeout else 60
|
|
||||||
w2.register_callback_for_hold_invoice(payment_hash, cb, timeout)
|
|
||||||
invoice_features = lnaddr.get_features()
|
invoice_features = lnaddr.get_features()
|
||||||
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
|
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
|
||||||
await group.spawn(pay(lnaddr, pay_req))
|
await group.spawn(pay(lnaddr, pay_req))
|
||||||
@@ -763,20 +764,23 @@ class TestPeer(ElectrumTestCase):
|
|||||||
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {
|
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {
|
||||||
'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey),
|
'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey),
|
||||||
}
|
}
|
||||||
with self.assertRaises(PaymentDone):
|
await f()
|
||||||
await f(test_hold_invoice=False)
|
|
||||||
with self.assertRaises(PaymentDone):
|
|
||||||
await f(test_hold_invoice=True, test_timeout=False)
|
|
||||||
with self.assertRaises(PaymentFailure):
|
|
||||||
await f(test_hold_invoice=True, test_timeout=True)
|
|
||||||
|
|
||||||
@needs_test_with_all_chacha20_implementations
|
@needs_test_with_all_chacha20_implementations
|
||||||
async def test_simple_payment(self):
|
async def test_simple_payment(self):
|
||||||
await self._test_simple_payment(trampoline=False)
|
for trampoline in [False, True]:
|
||||||
|
with self.assertRaises(PaymentDone):
|
||||||
|
await self._test_simple_payment(trampoline=trampoline)
|
||||||
|
|
||||||
@needs_test_with_all_chacha20_implementations
|
async def test_simple_payment_with_hold_invoice(self):
|
||||||
async def test_simple_payment_trampoline(self):
|
for trampoline in [False, True]:
|
||||||
await self._test_simple_payment(trampoline=True)
|
with self.assertRaises(PaymentDone):
|
||||||
|
await self._test_simple_payment(trampoline=trampoline, test_hold_invoice=True)
|
||||||
|
|
||||||
|
async def test_simple_payment_with_hold_invoice_timing_out(self):
|
||||||
|
for trampoline in [False, True]:
|
||||||
|
with self.assertRaises(PaymentFailure):
|
||||||
|
await self._test_simple_payment(trampoline=trampoline, test_hold_invoice=True, test_timeout=True)
|
||||||
|
|
||||||
@needs_test_with_all_chacha20_implementations
|
@needs_test_with_all_chacha20_implementations
|
||||||
async def test_payment_race(self):
|
async def test_payment_race(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user