diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 7d034a06d..dafff4db4 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -1791,13 +1791,25 @@ class Peer(Logger): payment_secret_from_onion = None if total_msat > amt_to_forward: - mpp_status = self.lnworker.check_received_mpp_htlc(payment_secret_from_onion, chan.short_channel_id, htlc, total_msat) - if mpp_status is None: + from .lnworker import RecvMPPResolution + mpp_resolution = self.lnworker.check_mpp_status( + payment_secret=payment_secret_from_onion, + short_channel_id=chan.short_channel_id, + htlc=htlc, + expected_msat=total_msat, + ) + if mpp_resolution == RecvMPPResolution.WAITING: return None, None - if mpp_status is False: + elif mpp_resolution == RecvMPPResolution.EXPIRED: log_fail_reason(f"MPP_TIMEOUT") raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'') - assert mpp_status is True + elif mpp_resolution == RecvMPPResolution.FAILED: + log_fail_reason(f"mpp_resolution is FAILED") + raise exc_incorrect_or_unknown_pd + elif mpp_resolution == RecvMPPResolution.ACCEPTED: + pass # continue + else: + raise Exception(f"unexpected {mpp_resolution=}") # if there is a trampoline_onion, maybe_fulfill_htlc will be called again if processed_onion.trampoline_onion_packet: diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 44c5fb0f5..e4e7adc11 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -8,7 +8,8 @@ from decimal import Decimal import random import time import operator -from enum import IntEnum +import enum +from enum import IntEnum, Enum from typing import (Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, NamedTuple, Union, Mapping, Any, Iterable, AsyncGenerator, DefaultDict) import threading @@ -167,6 +168,19 @@ class PaymentInfo(NamedTuple): status: int +class RecvMPPResolution(Enum): + WAITING = enum.auto() + EXPIRED = enum.auto() + ACCEPTED = enum.auto() + FAILED = enum.auto() + + +class ReceivedMPPStatus(NamedTuple): + resolution: RecvMPPResolution + expected_msat: int + htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]] + + class ErrorAddingPeer(Exception): pass @@ -657,8 +671,8 @@ class LNWallet(LNWorker): self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]] self.sent_htlcs_info = dict() # (RHASH, scid, htlc_id) -> route, payment_secret, amount_msat, bucket_msat, trampoline_fee_level - self.sent_buckets = dict() # payment_secret -> (amount_sent, amount_failed) - self.received_mpp_htlcs = dict() # RHASH -> mpp_status, htlc_set + self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed) + self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self) # detect inflight payments @@ -1397,13 +1411,14 @@ class LNWallet(LNWorker): key = (payment_hash, short_channel_id, htlc.htlc_id) self.sent_htlcs_info[key] = route, payment_secret, amount_msat, total_msat, amount_receiver_msat, trampoline_fee_level, trampoline_route + payment_key = payment_hash + payment_secret # if we sent MPP to a trampoline, add item to sent_buckets if self.uses_trampoline() and amount_msat != total_msat: - if payment_secret not in self.sent_buckets: - self.sent_buckets[payment_secret] = (0, 0) - amount_sent, amount_failed = self.sent_buckets[payment_secret] + if payment_key not in self.sent_buckets: + self.sent_buckets[payment_key] = (0, 0) + amount_sent, amount_failed = self.sent_buckets[payment_key] amount_sent += amount_receiver_msat - self.sent_buckets[payment_secret] = amount_sent, amount_failed + self.sent_buckets[payment_key] = amount_sent, amount_failed if self.network.path_finder: # add inflight htlcs to liquidity hints self.network.path_finder.update_inflight_htlcs(route, add_htlcs=True) @@ -1879,33 +1894,91 @@ class LNWallet(LNWorker): if write_to_disk: self.wallet.save_db() - def check_received_mpp_htlc(self, payment_secret, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]: - """ return MPP status: True (accepted), False (expired) or None """ + def check_mpp_status( + self, *, + payment_secret: bytes, + short_channel_id: ShortChannelID, + htlc: UpdateAddHtlc, + expected_msat: int, + ) -> RecvMPPResolution: payment_hash = htlc.payment_hash - is_expired, is_accepted, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, set())) - if self.get_payment_status(payment_hash) == PR_PAID: - # payment_status is persisted - is_accepted = True - is_expired = False - key = (short_channel_id, htlc) - if key not in htlc_set: - htlc_set.add(key) - if not is_accepted and not is_expired: - total = sum([_htlc.amount_msat for scid, _htlc in htlc_set]) - first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set]) - if self.stopping_soon: - is_expired = True # try to time out pending HTLCs before shutting down + payment_key = payment_hash + payment_secret + self.update_mpp_with_received_htlc( + payment_key=payment_key, scid=short_channel_id, htlc=htlc, expected_msat=expected_msat) + mpp_resolution = self.received_mpp_htlcs[payment_key].resolution + if mpp_resolution == RecvMPPResolution.WAITING: + first_timestamp = self.get_first_timestamp_of_mpp(payment_key) + if self.get_payment_status(payment_hash) == PR_PAID: + mpp_resolution = RecvMPPResolution.ACCEPTED + elif self.stopping_soon: + # try to time out pending HTLCs before shutting down + mpp_resolution = RecvMPPResolution.EXPIRED + elif self.is_mpp_amount_reached(payment_key): + mpp_resolution = RecvMPPResolution.ACCEPTED elif time.time() - first_timestamp > self.MPP_EXPIRY: - is_expired = True - elif total == expected_msat: - is_accepted = True - if is_accepted or is_expired: - htlc_set.remove(key) - if len(htlc_set) > 0: - self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, htlc_set - elif payment_secret in self.received_mpp_htlcs: - self.received_mpp_htlcs.pop(payment_secret) - return True if is_accepted else (False if is_expired else None) + mpp_resolution = RecvMPPResolution.EXPIRED + + if mpp_resolution != RecvMPPResolution.WAITING: + self.set_mpp_resolution(payment_key=payment_key, resolution=mpp_resolution) + + self.maybe_cleanup_mpp_status(payment_key, short_channel_id, htlc) + return mpp_resolution + + def update_mpp_with_received_htlc( + self, + *, + payment_key: bytes, + scid: ShortChannelID, + htlc: UpdateAddHtlc, + expected_msat: int, + ): + # add new htlc to set + mpp_status = self.received_mpp_htlcs.get(payment_key) + if mpp_status is None: + mpp_status = ReceivedMPPStatus( + resolution=RecvMPPResolution.WAITING, + expected_msat=expected_msat, + htlc_set=set(), + ) + if expected_msat != mpp_status.expected_msat: + self.logger.info( + f"marking received mpp as failed. inconsistent total_msats in bucket. {payment_key.hex()=}") + mpp_status = mpp_status._replace(resolution=RecvMPPResolution.FAILED) + key = (scid, htlc) + if key not in mpp_status.htlc_set: + mpp_status.htlc_set.add(key) # side-effecting htlc_set + self.received_mpp_htlcs[payment_key] = mpp_status + + def set_mpp_resolution(self, *, payment_key: bytes, resolution: RecvMPPResolution): + mpp_status = self.received_mpp_htlcs[payment_key] + self.received_mpp_htlcs[payment_key] = mpp_status._replace(resolution=resolution) + + def is_mpp_amount_reached(self, payment_key: bytes) -> bool: + mpp_status = self.received_mpp_htlcs.get(payment_key) + if not mpp_status: + return False + total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set]) + return total >= mpp_status.expected_msat + + def get_first_timestamp_of_mpp(self, payment_key: bytes) -> int: + mpp_status = self.received_mpp_htlcs.get(payment_key) + if not mpp_status: + return int(time.time()) + return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set]) + + def maybe_cleanup_mpp_status( + self, + payment_key: bytes, + short_channel_id: ShortChannelID, + htlc: UpdateAddHtlc, + ) -> None: + mpp_status = self.received_mpp_htlcs[payment_key] + if mpp_status.resolution == RecvMPPResolution.WAITING: + return + key = (short_channel_id, htlc) + mpp_status.htlc_set.remove(key) # side-effecting htlc_set + if not mpp_status.htlc_set and payment_key in self.received_mpp_htlcs: + self.received_mpp_htlcs.pop(payment_key) def get_payment_status(self, payment_hash: bytes) -> int: info = self.get_payment_info(payment_hash) @@ -2012,10 +2085,11 @@ class LNWallet(LNWorker): self.logger.info(f"htlc_failed {failure_message}") # check sent_buckets if we use trampoline - if self.uses_trampoline() and payment_secret in self.sent_buckets: - amount_sent, amount_failed = self.sent_buckets[payment_secret] + payment_key = payment_hash + payment_secret + if self.uses_trampoline() and payment_key in self.sent_buckets: + amount_sent, amount_failed = self.sent_buckets[payment_key] amount_failed += amount_receiver_msat - self.sent_buckets[payment_secret] = amount_sent, amount_failed + self.sent_buckets[payment_key] = amount_sent, amount_failed if amount_sent != amount_failed: self.logger.info('bucket still active...') return diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 095165469..ecb212972 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -245,7 +245,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): set_request_status = LNWallet.set_request_status set_payment_status = LNWallet.set_payment_status get_payment_status = LNWallet.get_payment_status - check_received_mpp_htlc = LNWallet.check_received_mpp_htlc + check_mpp_status = LNWallet.check_mpp_status htlc_fulfilled = LNWallet.htlc_fulfilled htlc_failed = LNWallet.htlc_failed save_preimage = LNWallet.save_preimage @@ -273,6 +273,11 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): _on_maybe_forwarded_htlc_resolved = LNWallet._on_maybe_forwarded_htlc_resolved _force_close_channel = LNWallet._force_close_channel suggest_splits = LNWallet.suggest_splits + update_mpp_with_received_htlc = LNWallet.update_mpp_with_received_htlc + set_mpp_resolution = LNWallet.set_mpp_resolution + is_mpp_amount_reached = LNWallet.is_mpp_amount_reached + get_first_timestamp_of_mpp = LNWallet.get_first_timestamp_of_mpp + maybe_cleanup_mpp_status = LNWallet.maybe_cleanup_mpp_status class MockTransport: