tests: add test for prev
This commit is contained in:
@@ -578,6 +578,7 @@ class LNWallet(LNWorker):
|
||||
|
||||
lnwatcher: Optional['LNWalletWatcher']
|
||||
MPP_EXPIRY = 120
|
||||
TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3 # seconds
|
||||
|
||||
def __init__(self, wallet: 'Abstract_Wallet', xprv):
|
||||
self.wallet = wallet
|
||||
@@ -713,11 +714,12 @@ class LNWallet(LNWorker):
|
||||
self.stopping_soon = True
|
||||
if self.listen_server: # stop accepting new peers
|
||||
self.listen_server.close()
|
||||
async with ignore_after(3):
|
||||
async with ignore_after(self.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS):
|
||||
await self.wait_for_received_pending_htlcs_to_get_removed()
|
||||
await super().stop()
|
||||
await self.lnwatcher.stop()
|
||||
self.lnwatcher = None
|
||||
await LNWorker.stop(self)
|
||||
if self.lnwatcher:
|
||||
await self.lnwatcher.stop()
|
||||
self.lnwatcher = None
|
||||
|
||||
async def wait_for_received_pending_htlcs_to_get_removed(self):
|
||||
assert self.stopping_soon is True
|
||||
|
||||
@@ -10,7 +10,7 @@ from concurrent import futures
|
||||
import unittest
|
||||
from typing import Iterable, NamedTuple, Tuple, List
|
||||
|
||||
from aiorpcx import TaskGroup
|
||||
from aiorpcx import TaskGroup, timeout_after, TaskTimeout
|
||||
|
||||
from electrum import bitcoin
|
||||
from electrum import constants
|
||||
@@ -113,7 +113,8 @@ class MockWallet:
|
||||
|
||||
|
||||
class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1
|
||||
MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1
|
||||
TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 0
|
||||
|
||||
def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name):
|
||||
self.name = name
|
||||
@@ -121,6 +122,9 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1)
|
||||
self.node_keypair = local_keypair
|
||||
self.network = MockNetwork(tx_queue)
|
||||
self.taskgroup = TaskGroup()
|
||||
self.lnwatcher = None
|
||||
self.listen_server = None
|
||||
self._channels = {chan.channel_id: chan for chan in chans}
|
||||
self.payments = {}
|
||||
self.logs = defaultdict(list)
|
||||
@@ -184,6 +188,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
return self.name
|
||||
|
||||
async def stop(self):
|
||||
await LNWallet.stop(self)
|
||||
if self.channel_db:
|
||||
self.channel_db.stop()
|
||||
await self.channel_db.stopped_event.wait()
|
||||
@@ -216,6 +221,8 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
_calc_routing_hints_for_invoice = LNWallet._calc_routing_hints_for_invoice
|
||||
handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc
|
||||
is_trampoline_peer = LNWallet.is_trampoline_peer
|
||||
wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed
|
||||
on_proxy_changed = LNWallet.on_proxy_changed
|
||||
|
||||
|
||||
class MockTransport:
|
||||
@@ -291,13 +298,9 @@ class SquareGraph(NamedTuple):
|
||||
def all_lnworkers(self) -> Iterable[MockLNWallet]:
|
||||
return self.w_a, self.w_b, self.w_c, self.w_d
|
||||
|
||||
async def stop_and_cleanup(self):
|
||||
async with TaskGroup() as group:
|
||||
for lnworker in self.all_lnworkers():
|
||||
await group.spawn(lnworker.stop())
|
||||
|
||||
|
||||
class PaymentDone(Exception): pass
|
||||
class TestSuccess(Exception): pass
|
||||
|
||||
|
||||
class TestPeer(ElectrumTestCase):
|
||||
@@ -837,6 +840,50 @@ class TestPeer(ElectrumTestCase):
|
||||
graph = self.prepare_chans_and_peers_in_square()
|
||||
self._run_mpp(graph, {'alice_uses_trampoline':True, 'attempts':1}, {'alice_uses_trampoline':True, 'attempts':3})
|
||||
|
||||
@needs_test_with_all_chacha20_implementations
|
||||
def test_fail_pending_htlcs_on_shutdown(self):
|
||||
"""Alice tries to pay Dave via MPP. Dave receives some HTLCs but not all.
|
||||
Dave shuts down (stops wallet).
|
||||
We test if Dave fails the pending HTLCs during shutdown.
|
||||
"""
|
||||
graph = self.prepare_chans_and_peers_in_square()
|
||||
self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL))
|
||||
self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL))
|
||||
amount_to_pay = 600_000_000_000
|
||||
peers = graph.all_peers()
|
||||
graph.w_d.MPP_EXPIRY = 120
|
||||
graph.w_d.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3
|
||||
async def pay():
|
||||
graph.w_d.features |= LnFeatures.BASIC_MPP_OPT
|
||||
graph.w_b.enable_htlc_forwarding.clear() # Bob will hold forwarded HTLCs
|
||||
assert graph.w_a.network.channel_db is not None
|
||||
lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True, amount_msat=amount_to_pay)
|
||||
try:
|
||||
async with timeout_after(0.5):
|
||||
result, log = await graph.w_a.pay_invoice(pay_req, attempts=1)
|
||||
except TaskTimeout:
|
||||
# by now Dave hopefully received some HTLCs:
|
||||
self.assertTrue(len(graph.chan_dc.hm.htlcs(LOCAL)) > 0)
|
||||
self.assertTrue(len(graph.chan_dc.hm.htlcs(REMOTE)) > 0)
|
||||
else:
|
||||
self.fail(f"pay_invoice finished but was not supposed to. result={result}")
|
||||
await graph.w_d.stop()
|
||||
# Dave is supposed to have failed the pending incomplete MPP HTLCs
|
||||
self.assertEqual(0, len(graph.chan_dc.hm.htlcs(LOCAL)))
|
||||
self.assertEqual(0, len(graph.chan_dc.hm.htlcs(REMOTE)))
|
||||
raise TestSuccess()
|
||||
|
||||
async def f():
|
||||
async with TaskGroup() as group:
|
||||
for peer in peers:
|
||||
await group.spawn(peer._message_loop())
|
||||
await group.spawn(peer.htlc_switch())
|
||||
await asyncio.sleep(0.2)
|
||||
await group.spawn(pay())
|
||||
|
||||
with self.assertRaises(TestSuccess):
|
||||
run(f())
|
||||
|
||||
@needs_test_with_all_chacha20_implementations
|
||||
def test_close(self):
|
||||
alice_channel, bob_channel = create_test_channels()
|
||||
|
||||
Reference in New Issue
Block a user