Merge pull request #7099 from SomberNight/202103_fail_pending_htlcs_on_shutdown
fail pending htlcs on shutdown
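In short: on shutdown, LNWallet.stop() now sets a `stopping_soon` flag, closes the listen server so no new peers are accepted, lets incomplete MPP HTLC sets expire immediately so they get failed, and then waits up to `TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS` seconds for those removals to become irrevocable before tearing down the peers and the watcher. Below is a minimal sketch (not part of the commit; the coroutine and its names are invented) of the bounded-wait primitive used for that last step, `aiorpcx.ignore_after`:

```python
# Minimal sketch, NOT part of this commit: the bounded-wait primitive that
# LNWallet.stop() uses in the diff below.
import asyncio
from aiorpcx import ignore_after

async def shutdown_grace_period(pending_removed: asyncio.Event, timeout: float = 3.0) -> None:
    # Give in-flight HTLC removals up to `timeout` seconds to become irrevocable;
    # if the event never fires, shutdown simply proceeds after the timeout.
    async with ignore_after(timeout):
        await pending_removed.wait()
    # ...continue tearing down peers, watcher, etc. here...

if __name__ == '__main__':
    asyncio.run(shutdown_grace_period(asyncio.Event(), timeout=0.1))
```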
@@ -359,6 +359,55 @@ class HTLCManager:
             return False
         return ctns[ctx_owner] <= self.ctn_oldest_unrevoked(ctx_owner)

+    @with_lock
+    def is_htlc_irrevocably_removed_yet(
+            self,
+            *,
+            ctx_owner: HTLCOwner = None,
+            htlc_proposer: HTLCOwner,
+            htlc_id: int,
+    ) -> bool:
+        """Returns whether the removal of an htlc was irrevocably committed to `ctx_owner's` ctx.
+        The removal can either be a fulfill/settle or a fail; they are not distinguished.
+        If `ctx_owner` is None, both parties' ctxs are checked.
+        """
+        in_local = self._is_htlc_irrevocably_removed_yet(
+            ctx_owner=LOCAL, htlc_proposer=htlc_proposer, htlc_id=htlc_id)
+        in_remote = self._is_htlc_irrevocably_removed_yet(
+            ctx_owner=REMOTE, htlc_proposer=htlc_proposer, htlc_id=htlc_id)
+        if ctx_owner is None:
+            return in_local and in_remote
+        elif ctx_owner == LOCAL:
+            return in_local
+        elif ctx_owner == REMOTE:
+            return in_remote
+        else:
+            raise Exception(f"unexpected ctx_owner: {ctx_owner!r}")
+
+    @with_lock
+    def _is_htlc_irrevocably_removed_yet(
+            self,
+            *,
+            ctx_owner: HTLCOwner,
+            htlc_proposer: HTLCOwner,
+            htlc_id: int,
+    ) -> bool:
+        htlc_id = int(htlc_id)
+        if htlc_id >= self.get_next_htlc_id(htlc_proposer):
+            return False
+        if htlc_id in self.log[htlc_proposer]['settles']:
+            ctn_of_settle = self.log[htlc_proposer]['settles'][htlc_id][ctx_owner]
+        else:
+            ctn_of_settle = None
+        if htlc_id in self.log[htlc_proposer]['fails']:
+            ctn_of_fail = self.log[htlc_proposer]['fails'][htlc_id][ctx_owner]
+        else:
+            ctn_of_fail = None
+        ctn_of_rm = ctn_of_settle or ctn_of_fail or None
+        if ctn_of_rm is None:
+            return False
+        return ctn_of_rm <= self.ctn_oldest_unrevoked(ctx_owner)
+
     @with_lock
     def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction,
                            ctn: int = None) -> Dict[int, UpdateAddHtlc]:
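The HTLCManager methods added above boil down to one comparison: a settle/fail recorded at commitment number N is final for a party once every commitment that party could still broadcast (ctn >= its oldest unrevoked ctn) already reflects the removal. A toy illustration, not Electrum code (the helper name is made up):

```python
# Toy illustration, not Electrum code: the core of _is_htlc_irrevocably_removed_yet.
from typing import Optional

def removal_is_irrevocable(ctn_of_removal: Optional[int], ctn_oldest_unrevoked: int) -> bool:
    if ctn_of_removal is None:
        # the settle/fail has not been committed to any commitment tx yet
        return False
    # every commitment the owner could still broadcast (ctn >= oldest unrevoked)
    # already reflects the removal
    return ctn_of_removal <= ctn_oldest_unrevoked

assert removal_is_irrevocable(None, 5) is False
assert removal_is_irrevocable(7, 5) is False  # older, unrevoked ctxs still contain the HTLC
assert removal_is_irrevocable(4, 5) is True
```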
@@ -9,11 +9,12 @@ from collections import OrderedDict, defaultdict
 import asyncio
 import os
 import time
-from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union
+from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set
 from datetime import datetime
 import functools

 import aiorpcx
+from aiorpcx import TaskGroup

 from .crypto import sha256, sha256d
 from . import bitcoin, util

@@ -74,6 +75,7 @@ class Peer(Logger):
         self._sent_init = False  # type: bool
         self._received_init = False  # type: bool
         self.initialized = asyncio.Future()
+        self.got_disconnected = asyncio.Event()
         self.querying = asyncio.Event()
         self.transport = transport
         self.pubkey = pubkey  # remote pubkey

@@ -98,6 +100,11 @@ class Peer(Logger):
         self.orphan_channel_updates = OrderedDict()
         Logger.__init__(self)
         self.taskgroup = SilentTaskGroup()
+        # HTLCs offered by REMOTE, that we started removing but are still active:
+        self.received_htlcs_pending_removal = set()  # type: Set[Tuple[Channel, int]]
+        self.received_htlc_removed_event = asyncio.Event()
+        self._htlc_switch_iterstart_event = asyncio.Event()
+        self._htlc_switch_iterdone_event = asyncio.Event()

     def send_message(self, message_name: str, **kwargs):
         assert type(message_name) is str

@@ -492,6 +499,7 @@
         except:
             pass
         self.lnworker.peer_closed(self)
+        self.got_disconnected.set()

     def is_static_remotekey(self):
         return self.features.supports(LnFeatures.OPTION_STATIC_REMOTEKEY_OPT)

@@ -1575,6 +1583,7 @@
         self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
         assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
         assert chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id)
+        self.received_htlcs_pending_removal.add((chan, htlc_id))
         chan.settle_htlc(preimage, htlc_id)
         self.send_message(
             "update_fulfill_htlc",

@@ -1585,6 +1594,7 @@
     def fail_htlc(self, *, chan: Channel, htlc_id: int, error_bytes: bytes):
         self.logger.info(f"fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.")
         assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
+        self.received_htlcs_pending_removal.add((chan, htlc_id))
         chan.fail_htlc(htlc_id)
         self.send_message(
             "update_fail_htlc",

@@ -1596,9 +1606,10 @@
     def fail_malformed_htlc(self, *, chan: Channel, htlc_id: int, reason: OnionRoutingFailure):
         self.logger.info(f"fail_malformed_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.")
         assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
-        chan.fail_htlc(htlc_id)
         if not (reason.code & OnionFailureCodeMetaFlag.BADONION and len(reason.data) == 32):
             raise Exception(f"unexpected reason when sending 'update_fail_malformed_htlc': {reason!r}")
+        self.received_htlcs_pending_removal.add((chan, htlc_id))
+        chan.fail_htlc(htlc_id)
         self.send_message(
             "update_fail_malformed_htlc",
             channel_id=chan.channel_id,

@@ -1800,8 +1811,13 @@
     async def htlc_switch(self):
         await self.initialized
         while True:
-            await asyncio.sleep(0.1)
+            self._htlc_switch_iterdone_event.set()
+            self._htlc_switch_iterdone_event.clear()
+            await asyncio.sleep(0.1)  # TODO maybe make this partly event-driven
+            self._htlc_switch_iterstart_event.set()
+            self._htlc_switch_iterstart_event.clear()
             self.ping_if_required()
+            self._maybe_cleanup_received_htlcs_pending_removal()
             for chan_id, chan in self.channels.items():
                 if not chan.can_send_ctx_updates():
                     continue

@@ -1853,6 +1869,29 @@
                 for htlc_id in done:
                     unfulfilled.pop(htlc_id)

+    def _maybe_cleanup_received_htlcs_pending_removal(self) -> None:
+        done = set()
+        for chan, htlc_id in self.received_htlcs_pending_removal:
+            if chan.hm.is_htlc_irrevocably_removed_yet(htlc_proposer=REMOTE, htlc_id=htlc_id):
+                done.add((chan, htlc_id))
+        if done:
+            for key in done:
+                self.received_htlcs_pending_removal.remove(key)
+            self.received_htlc_removed_event.set()
+            self.received_htlc_removed_event.clear()
+
+    async def wait_one_htlc_switch_iteration(self) -> None:
+        """Waits until the HTLC switch does a full iteration or the peer disconnects,
+        whichever happens first.
+        """
+        async def htlc_switch_iteration():
+            await self._htlc_switch_iterstart_event.wait()
+            await self._htlc_switch_iterdone_event.wait()
+
+        async with TaskGroup(wait=any) as group:
+            await group.spawn(htlc_switch_iteration())
+            await group.spawn(self.got_disconnected.wait())
+
     async def process_unfulfilled_htlc(
             self, *,
             chan: Channel,
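Peer.wait_one_htlc_switch_iteration() above races two awaitables: one full start-to-done cycle of the HTLC switch, and the peer disconnecting. The sketch below (not Electrum code; names are invented) shows the `aiorpcx.TaskGroup(wait=any)` pattern it relies on, where the group returns as soon as the first spawned task finishes and the remaining task is cancelled when the block exits:

```python
# Minimal sketch, not Electrum code: race two events with aiorpcx.TaskGroup(wait=any).
import asyncio
from aiorpcx import TaskGroup

async def wait_first(event_a: asyncio.Event, event_b: asyncio.Event) -> None:
    async with TaskGroup(wait=any) as group:
        await group.spawn(event_a.wait())
        await group.spawn(event_b.wait())
    # leaving the block: the still-pending task gets cancelled

async def demo() -> None:
    a, b = asyncio.Event(), asyncio.Event()
    asyncio.get_running_loop().call_later(0.1, b.set)
    await wait_first(a, b)  # returns after ~0.1s even though `a` never fires

asyncio.run(demo())
```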
@@ -22,7 +22,7 @@ import urllib.parse

 import dns.resolver
 import dns.exception
-from aiorpcx import run_in_thread, TaskGroup, NetAddress
+from aiorpcx import run_in_thread, TaskGroup, NetAddress, ignore_after

 from . import constants, util
 from . import keystore

@@ -195,6 +195,7 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
         self.features = features
         self.network = None  # type: Optional[Network]
         self.config = None  # type: Optional[SimpleConfig]
+        self.stopping_soon = False  # whether we are being shut down

         util.register_callback(self.on_proxy_changed, ['proxy_set'])

@@ -268,6 +269,8 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
     async def _maintain_connectivity(self):
         while True:
             await asyncio.sleep(1)
+            if self.stopping_soon:
+                return
             now = time.time()
             if len(self._peers) >= NUM_PEERS_TARGET:
                 continue

@@ -575,6 +578,7 @@ class LNWallet(LNWorker):

     lnwatcher: Optional['LNWalletWatcher']
     MPP_EXPIRY = 120
+    TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3  # seconds

     def __init__(self, wallet: 'Abstract_Wallet', xprv):
         self.wallet = wallet

@@ -707,9 +711,32 @@
         asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)

     async def stop(self):
-        await super().stop()
-        await self.lnwatcher.stop()
-        self.lnwatcher = None
+        self.stopping_soon = True
+        if self.listen_server:  # stop accepting new peers
+            self.listen_server.close()
+        async with ignore_after(self.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS):
+            await self.wait_for_received_pending_htlcs_to_get_removed()
+        await LNWorker.stop(self)
+        if self.lnwatcher:
+            await self.lnwatcher.stop()
+            self.lnwatcher = None
+
+    async def wait_for_received_pending_htlcs_to_get_removed(self):
+        assert self.stopping_soon is True
+        # We try to fail pending MPP HTLCs, and wait a bit for them to get removed.
+        # Note: even without MPP, if we just failed/fulfilled an HTLC, it is good
+        # to wait a bit for it to become irrevocably removed.
+        # Note: we don't wait for *all htlcs* to get removed, only for those
+        # that we can already fail/fulfill. e.g. forwarded htlcs cannot be removed
+        async with TaskGroup() as group:
+            for peer in self.peers.values():
+                await group.spawn(peer.wait_one_htlc_switch_iteration())
+        while True:
+            if all(not peer.received_htlcs_pending_removal for peer in self.peers.values()):
+                break
+            async with TaskGroup(wait=any) as group:
+                for peer in self.peers.values():
+                    await group.spawn(peer.received_htlc_removed_event.wait())

     def peer_closed(self, peer):
         for chan in self.channels_for_peer(peer.pubkey).values():
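LNWallet.wait_for_received_pending_htlcs_to_get_removed() above first gives every peer one HTLC-switch iteration (so the pending failures/fulfills actually get sent), then loops: if any peer still tracks a pending removal, it sleeps until some peer pulses its `received_htlc_removed_event` and re-checks. A self-contained sketch of that loop shape (not Electrum code; the class and names are invented):

```python
# Sketch, not Electrum code: wait until no worker has pending items, waking up
# whenever any of them signals progress (the set()/clear() "pulse" pattern).
import asyncio
from typing import List, Set
from aiorpcx import TaskGroup

class Worker:
    def __init__(self) -> None:
        self.pending = set()           # type: Set[int]
        self.progress = asyncio.Event()

    def finish_one(self) -> None:
        if self.pending:
            self.pending.pop()
            self.progress.set()        # wake everyone currently waiting...
            self.progress.clear()      # ...and re-arm for the next pulse

async def wait_until_drained(workers: List[Worker]) -> None:
    while True:
        if all(not w.pending for w in workers):
            break
        async with TaskGroup(wait=any) as group:
            for w in workers:
                await group.spawn(w.progress.wait())
```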
@@ -1635,7 +1662,9 @@
         if not is_accepted and not is_expired:
             total = sum([_htlc.amount_msat for scid, _htlc in htlc_set])
             first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set])
-            if time.time() - first_timestamp > self.MPP_EXPIRY:
+            if self.stopping_soon:
+                is_expired = True  # try to time out pending HTLCs before shutting down
+            elif time.time() - first_timestamp > self.MPP_EXPIRY:
                 is_expired = True
             elif total == expected_msat:
                 is_accepted = True
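With the hunk above, an incomplete MPP set is treated as expired as soon as the wallet is shutting down, instead of only after MPP_EXPIRY seconds, and the expiry path is what fails the held HTLCs. A compact restatement as a pure function (a sketch, not Electrum code; names are invented and the surrounding accepted/expired bookkeeping is omitted):

```python
# Sketch, not Electrum code: the amended decision for an incomplete MPP set.
import time

def mpp_set_status(*, stopping_soon: bool, total_msat: int, expected_msat: int,
                   first_timestamp: float, mpp_expiry: float) -> str:
    if stopping_soon:
        return 'expired'   # time the set out early so its HTLCs get failed before shutdown
    if time.time() - first_timestamp > mpp_expiry:
        return 'expired'
    if total_msat == expected_msat:
        return 'accepted'
    return 'waiting'
```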
@@ -1897,6 +1926,8 @@
     async def reestablish_peers_and_channels(self):
         while True:
             await asyncio.sleep(1)
+            if self.stopping_soon:
+                return
             for chan in self.channels.values():
                 if chan.is_closed():
                     continue
@@ -10,7 +10,7 @@ from concurrent import futures
 import unittest
 from typing import Iterable, NamedTuple, Tuple, List

-from aiorpcx import TaskGroup
+from aiorpcx import TaskGroup, timeout_after, TaskTimeout

 from electrum import bitcoin
 from electrum import constants

@@ -113,7 +113,8 @@ class MockWallet:


 class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
     MPP_EXPIRY = 2  # HTLC timestamps are cast to int, so this cannot be 1
+    TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 0

     def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name):
         self.name = name

@@ -121,6 +122,9 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
         NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1)
         self.node_keypair = local_keypair
         self.network = MockNetwork(tx_queue)
+        self.taskgroup = TaskGroup()
+        self.lnwatcher = None
+        self.listen_server = None
         self._channels = {chan.channel_id: chan for chan in chans}
         self.payments = {}
         self.logs = defaultdict(list)

@@ -147,6 +151,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
         self.trampoline_forwarding_failures = {}
         self.inflight_payments = set()
         self.preimages = {}
+        self.stopping_soon = False

     def get_invoice_status(self, key):
         pass

@@ -183,6 +188,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
         return self.name

     async def stop(self):
+        await LNWallet.stop(self)
         if self.channel_db:
             self.channel_db.stop()
             await self.channel_db.stopped_event.wait()

@@ -215,6 +221,8 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
     _calc_routing_hints_for_invoice = LNWallet._calc_routing_hints_for_invoice
     handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc
     is_trampoline_peer = LNWallet.is_trampoline_peer
+    wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed
+    on_proxy_changed = LNWallet.on_proxy_changed


 class MockTransport:

@@ -290,13 +298,9 @@ class SquareGraph(NamedTuple):
     def all_lnworkers(self) -> Iterable[MockLNWallet]:
         return self.w_a, self.w_b, self.w_c, self.w_d
-
-    async def stop_and_cleanup(self):
-        async with TaskGroup() as group:
-            for lnworker in self.all_lnworkers():
-                await group.spawn(lnworker.stop())


 class PaymentDone(Exception): pass
+class TestSuccess(Exception): pass


 class TestPeer(ElectrumTestCase):

@@ -836,6 +840,50 @@ class TestPeer(ElectrumTestCase):
         graph = self.prepare_chans_and_peers_in_square()
         self._run_mpp(graph, {'alice_uses_trampoline':True, 'attempts':1}, {'alice_uses_trampoline':True, 'attempts':3})

+    @needs_test_with_all_chacha20_implementations
+    def test_fail_pending_htlcs_on_shutdown(self):
+        """Alice tries to pay Dave via MPP. Dave receives some HTLCs but not all.
+        Dave shuts down (stops wallet).
+        We test if Dave fails the pending HTLCs during shutdown.
+        """
+        graph = self.prepare_chans_and_peers_in_square()
+        self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL))
+        self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL))
+        amount_to_pay = 600_000_000_000
+        peers = graph.all_peers()
+        graph.w_d.MPP_EXPIRY = 120
+        graph.w_d.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3
+        async def pay():
+            graph.w_d.features |= LnFeatures.BASIC_MPP_OPT
+            graph.w_b.enable_htlc_forwarding.clear()  # Bob will hold forwarded HTLCs
+            assert graph.w_a.network.channel_db is not None
+            lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True, amount_msat=amount_to_pay)
+            try:
+                async with timeout_after(0.5):
+                    result, log = await graph.w_a.pay_invoice(pay_req, attempts=1)
+            except TaskTimeout:
+                # by now Dave hopefully received some HTLCs:
+                self.assertTrue(len(graph.chan_dc.hm.htlcs(LOCAL)) > 0)
+                self.assertTrue(len(graph.chan_dc.hm.htlcs(REMOTE)) > 0)
+            else:
+                self.fail(f"pay_invoice finished but was not supposed to. result={result}")
+            await graph.w_d.stop()
+            # Dave is supposed to have failed the pending incomplete MPP HTLCs
+            self.assertEqual(0, len(graph.chan_dc.hm.htlcs(LOCAL)))
+            self.assertEqual(0, len(graph.chan_dc.hm.htlcs(REMOTE)))
+            raise TestSuccess()
+
+        async def f():
+            async with TaskGroup() as group:
+                for peer in peers:
+                    await group.spawn(peer._message_loop())
+                    await group.spawn(peer.htlc_switch())
+                await asyncio.sleep(0.2)
+                await group.spawn(pay())
+
+        with self.assertRaises(TestSuccess):
+            run(f())
+
     @needs_test_with_all_chacha20_implementations
     def test_close(self):
         alice_channel, bob_channel = create_test_channels()
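The new test interrupts pay_invoice() mid-flight with `aiorpcx.timeout_after`, so Dave is left holding part of an MPP payment, and then asserts that `graph.w_d.stop()` removes those HTLCs from both commitment transactions. A minimal sketch of the timeout pattern used there (not Electrum code):

```python
# Sketch, not Electrum code: aiorpcx's timeout_after/TaskTimeout pattern, as used
# in test_fail_pending_htlcs_on_shutdown to interrupt pay_invoice() on purpose.
import asyncio
from aiorpcx import timeout_after, TaskTimeout

async def demo() -> None:
    try:
        async with timeout_after(0.5):
            await asyncio.sleep(10)      # stands in for pay_invoice(...)
    except TaskTimeout:
        print("timed out as expected")   # now inspect the half-finished state
    else:
        raise AssertionError("was not supposed to finish")

asyncio.run(demo())
```

Assuming the usual repository layout, the test can presumably be run on its own with e.g. `python -m pytest electrum/tests/test_lnpeer.py -k test_fail_pending_htlcs_on_shutdown`.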