Persist MPP resolution status in wallet file.
If we accept a MPP and we forward the payment (trampoline or swap), we need to persist the payment accepted status, or we might wrongly release htlcs on the next restart. lnworker.received_mpp_htlcs used to be cleaned up in maybe_cleanup_forwarding, which only applies to forwarded payments. However, since we now persist this dict, we need to clean it up also in the case of payments received by us. This part of maybe_cleanup_forwarding has been migrated to lnworker.maybe_cleanup_mpp
This commit is contained in:
@@ -2750,6 +2750,7 @@ class Peer(Logger):
|
||||
# return payment_key so this branch will not be executed again
|
||||
return None, payment_key, None
|
||||
elif preimage:
|
||||
self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc)
|
||||
return preimage, None, None
|
||||
else:
|
||||
# we are waiting for mpp consolidation or preimage
|
||||
@@ -2761,7 +2762,10 @@ class Peer(Logger):
|
||||
preimage = self.lnworker.get_preimage(payment_hash)
|
||||
error_bytes, error_reason = self.lnworker.get_forwarding_failure(payment_key)
|
||||
if error_bytes or error_reason or preimage:
|
||||
self.lnworker.maybe_cleanup_forwarding(payment_key, chan.get_scid_or_local_alias(), htlc)
|
||||
cleanup_keys = self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc)
|
||||
is_htlc_key = ':' in payment_key
|
||||
if is_htlc_key or payment_key in cleanup_keys:
|
||||
self.lnworker.maybe_cleanup_forwarding(payment_key)
|
||||
if error_bytes:
|
||||
return None, None, error_bytes
|
||||
if error_reason:
|
||||
|
||||
@@ -87,6 +87,7 @@ from .submarine_swaps import HttpSwapManager
|
||||
from .channel_db import ChannelInfo, Policy
|
||||
from .mpp_split import suggest_splits, SplitConfigRating
|
||||
from .trampoline import create_trampoline_route_and_onion, is_legacy_relay
|
||||
from .json_db import stored_in
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .network import Network
|
||||
@@ -169,11 +170,13 @@ class PaymentInfo(NamedTuple):
|
||||
status: int
|
||||
|
||||
|
||||
class RecvMPPResolution(Enum):
|
||||
WAITING = enum.auto()
|
||||
EXPIRED = enum.auto()
|
||||
ACCEPTED = enum.auto()
|
||||
FAILED = enum.auto()
|
||||
# Note: these states are persisted in the wallet file.
|
||||
# Do not modify them without performing a wallet db upgrade
|
||||
class RecvMPPResolution(IntEnum):
|
||||
WAITING = 0
|
||||
EXPIRED = 1
|
||||
ACCEPTED = 2
|
||||
FAILED = 3
|
||||
|
||||
|
||||
class ReceivedMPPStatus(NamedTuple):
|
||||
@@ -181,6 +184,13 @@ class ReceivedMPPStatus(NamedTuple):
|
||||
expected_msat: int
|
||||
htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]]
|
||||
|
||||
@stored_in('received_mpp_htlcs', tuple)
|
||||
def from_tuple(resolution, expected_msat, htlc_list) -> 'ReceivedMPPStatus':
|
||||
htlc_set = set([(ShortChannelID(bytes.fromhex(scid)), UpdateAddHtlc.from_tuple(*x)) for (scid,x) in htlc_list])
|
||||
return ReceivedMPPStatus(
|
||||
resolution=RecvMPPResolution(resolution),
|
||||
expected_msat=expected_msat,
|
||||
htlc_set=htlc_set)
|
||||
|
||||
SentHtlcKey = Tuple[bytes, ShortChannelID, int] # RHASH, scid, htlc_id
|
||||
|
||||
@@ -851,7 +861,7 @@ class LNWallet(LNWorker):
|
||||
|
||||
self._paysessions = dict() # type: Dict[bytes, PaySession]
|
||||
self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo]
|
||||
self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
|
||||
self.received_mpp_htlcs = self.db.get_dict('received_mpp_htlcs') # type: Dict[str, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
|
||||
|
||||
# detect inflight payments
|
||||
self.inflight_payments = set() # (not persisted) keys of invoices that are in PR_INFLIGHT state
|
||||
@@ -2192,7 +2202,7 @@ class LNWallet(LNWorker):
|
||||
payment_keys = [self._get_payment_key(x) for x in hash_list]
|
||||
self.payment_bundles.append(payment_keys)
|
||||
|
||||
def get_payment_bundle(self, payment_key):
|
||||
def get_payment_bundle(self, payment_key: bytes) -> Sequence[bytes]:
|
||||
for key_list in self.payment_bundles:
|
||||
if payment_key in key_list:
|
||||
return key_list
|
||||
@@ -2259,7 +2269,7 @@ class LNWallet(LNWorker):
|
||||
payment_key = payment_hash + payment_secret
|
||||
self.update_mpp_with_received_htlc(
|
||||
payment_key=payment_key, scid=short_channel_id, htlc=htlc, expected_msat=expected_msat)
|
||||
mpp_resolution = self.received_mpp_htlcs[payment_key].resolution
|
||||
mpp_resolution = self.received_mpp_htlcs[payment_key.hex()].resolution
|
||||
# if still waiting, calc resolution now:
|
||||
if mpp_resolution == RecvMPPResolution.WAITING:
|
||||
bundle = self.get_payment_bundle(payment_key)
|
||||
@@ -2280,7 +2290,7 @@ class LNWallet(LNWorker):
|
||||
# save resolution, if any.
|
||||
if mpp_resolution != RecvMPPResolution.WAITING:
|
||||
for pkey in payment_keys:
|
||||
if pkey in self.received_mpp_htlcs:
|
||||
if pkey.hex() in self.received_mpp_htlcs:
|
||||
self.set_mpp_resolution(payment_key=pkey, resolution=mpp_resolution)
|
||||
|
||||
return mpp_resolution
|
||||
@@ -2294,7 +2304,7 @@ class LNWallet(LNWorker):
|
||||
expected_msat: int,
|
||||
):
|
||||
# add new htlc to set
|
||||
mpp_status = self.received_mpp_htlcs.get(payment_key)
|
||||
mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
|
||||
if mpp_status is None:
|
||||
mpp_status = ReceivedMPPStatus(
|
||||
resolution=RecvMPPResolution.WAITING,
|
||||
@@ -2308,47 +2318,46 @@ class LNWallet(LNWorker):
|
||||
key = (scid, htlc)
|
||||
if key not in mpp_status.htlc_set:
|
||||
mpp_status.htlc_set.add(key) # side-effecting htlc_set
|
||||
self.received_mpp_htlcs[payment_key] = mpp_status
|
||||
self.received_mpp_htlcs[payment_key.hex()] = mpp_status
|
||||
|
||||
def set_mpp_resolution(self, *, payment_key: bytes, resolution: RecvMPPResolution):
|
||||
mpp_status = self.received_mpp_htlcs[payment_key]
|
||||
self.received_mpp_htlcs[payment_key] = mpp_status._replace(resolution=resolution)
|
||||
mpp_status = self.received_mpp_htlcs[payment_key.hex()]
|
||||
self.logger.info(f'set_mpp_resolution {resolution.name} {len(mpp_status.htlc_set)} {payment_key.hex()}')
|
||||
self.received_mpp_htlcs[payment_key.hex()] = mpp_status._replace(resolution=resolution)
|
||||
|
||||
def is_mpp_amount_reached(self, payment_key: bytes) -> bool:
|
||||
mpp_status = self.received_mpp_htlcs.get(payment_key)
|
||||
mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
|
||||
if not mpp_status:
|
||||
return False
|
||||
total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set])
|
||||
return total >= mpp_status.expected_msat
|
||||
|
||||
def get_first_timestamp_of_mpp(self, payment_key: bytes) -> int:
|
||||
mpp_status = self.received_mpp_htlcs.get(payment_key)
|
||||
mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
|
||||
if not mpp_status:
|
||||
return int(time.time())
|
||||
return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set])
|
||||
|
||||
def maybe_cleanup_forwarding(
|
||||
def maybe_cleanup_mpp(
|
||||
self,
|
||||
payment_key_hex: str,
|
||||
short_channel_id: ShortChannelID,
|
||||
htlc: UpdateAddHtlc,
|
||||
) -> None:
|
||||
|
||||
is_htlc_key = ':' in payment_key_hex
|
||||
if not is_htlc_key:
|
||||
payment_key = bytes.fromhex(payment_key_hex)
|
||||
mpp_status = self.received_mpp_htlcs.get(payment_key)
|
||||
if not mpp_status or mpp_status.resolution == RecvMPPResolution.WAITING:
|
||||
# After restart, self.received_mpp_htlcs needs to be reconstructed
|
||||
self.logger.info(f'maybe_cleanup_forwarding: mpp_status not ready')
|
||||
return
|
||||
htlc_key = (short_channel_id, htlc)
|
||||
) -> Sequence[str]:
|
||||
htlc_key = (short_channel_id, htlc)
|
||||
cleanup_keys = []
|
||||
for payment_key_hex, mpp_status in list(self.received_mpp_htlcs.items()):
|
||||
if htlc_key not in mpp_status.htlc_set:
|
||||
continue
|
||||
assert mpp_status.resolution != RecvMPPResolution.WAITING
|
||||
self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP {payment_key_hex}')
|
||||
mpp_status.htlc_set.remove(htlc_key) # side-effecting htlc_set
|
||||
if mpp_status.htlc_set:
|
||||
return
|
||||
self.logger.info('cleaning up mpp')
|
||||
self.received_mpp_htlcs.pop(payment_key)
|
||||
if len(mpp_status.htlc_set) == 0:
|
||||
self.logger.info(f'maybe_cleanup_mpp: removing mpp {payment_key_hex}')
|
||||
self.received_mpp_htlcs.pop(payment_key_hex)
|
||||
cleanup_keys.append(payment_key_hex)
|
||||
return cleanup_keys
|
||||
|
||||
def maybe_cleanup_forwarding(self, payment_key_hex: str) -> None:
|
||||
self.active_forwardings.pop(payment_key_hex, None)
|
||||
self.forwarding_failures.pop(payment_key_hex, None)
|
||||
|
||||
|
||||
@@ -316,6 +316,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
maybe_cleanup_forwarding = LNWallet.maybe_cleanup_forwarding
|
||||
current_target_feerate_per_kw = LNWallet.current_target_feerate_per_kw
|
||||
current_low_feerate_per_kw = LNWallet.current_low_feerate_per_kw
|
||||
maybe_cleanup_mpp = LNWallet.maybe_cleanup_mpp
|
||||
|
||||
|
||||
class MockTransport:
|
||||
@@ -1741,6 +1742,7 @@ class TestPeerForwarding(TestPeer):
|
||||
):
|
||||
alice_w = graph.workers['alice']
|
||||
bob_w = graph.workers['bob']
|
||||
carol_w = graph.workers['carol']
|
||||
dave_w = graph.workers['dave']
|
||||
if mpp_invoice:
|
||||
dave_w.features |= LnFeatures.BASIC_MPP_OPT
|
||||
@@ -1762,6 +1764,12 @@ class TestPeerForwarding(TestPeer):
|
||||
await asyncio.sleep(2)
|
||||
if result:
|
||||
self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash))
|
||||
# check mpp is cleaned up
|
||||
async with OldTaskGroup() as g:
|
||||
for peer in peers:
|
||||
await g.spawn(peer.wait_one_htlc_switch_iteration())
|
||||
for peer in peers:
|
||||
self.assertEqual(len(peer.lnworker.received_mpp_htlcs), 0)
|
||||
raise PaymentDone()
|
||||
elif len(log) == 1 and log[0].failure_msg.code == OnionFailureCode.MPP_TIMEOUT:
|
||||
raise PaymentTimeout()
|
||||
|
||||
Reference in New Issue
Block a user