1
0

lnworker: add RecvMPPResolution with "FAILED" state

- add RecvMPPResolution enum for possible states of a pending incoming MPP,
  and use it in check_mpp_status
  - new state: "FAILED", to allow nicely failing back the whole MPP set
- key more things with payment_hash+payment_secret, for consistency
  (just payment_hash is insufficient for trampoline forwarding)
This commit is contained in:
SomberNight
2023-08-04 13:27:05 +00:00
parent a4a184c6f5
commit c5300c9f1c
3 changed files with 131 additions and 40 deletions

View File

@@ -1791,13 +1791,25 @@ class Peer(Logger):
payment_secret_from_onion = None
if total_msat > amt_to_forward:
mpp_status = self.lnworker.check_received_mpp_htlc(payment_secret_from_onion, chan.short_channel_id, htlc, total_msat)
if mpp_status is None:
from .lnworker import RecvMPPResolution
mpp_resolution = self.lnworker.check_mpp_status(
payment_secret=payment_secret_from_onion,
short_channel_id=chan.short_channel_id,
htlc=htlc,
expected_msat=total_msat,
)
if mpp_resolution == RecvMPPResolution.WAITING:
return None, None
if mpp_status is False:
elif mpp_resolution == RecvMPPResolution.EXPIRED:
log_fail_reason(f"MPP_TIMEOUT")
raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'')
assert mpp_status is True
elif mpp_resolution == RecvMPPResolution.FAILED:
log_fail_reason(f"mpp_resolution is FAILED")
raise exc_incorrect_or_unknown_pd
elif mpp_resolution == RecvMPPResolution.ACCEPTED:
pass # continue
else:
raise Exception(f"unexpected {mpp_resolution=}")
# if there is a trampoline_onion, maybe_fulfill_htlc will be called again
if processed_onion.trampoline_onion_packet:

View File

@@ -8,7 +8,8 @@ from decimal import Decimal
import random
import time
import operator
from enum import IntEnum
import enum
from enum import IntEnum, Enum
from typing import (Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING,
NamedTuple, Union, Mapping, Any, Iterable, AsyncGenerator, DefaultDict)
import threading
@@ -167,6 +168,19 @@ class PaymentInfo(NamedTuple):
status: int
class RecvMPPResolution(Enum):
WAITING = enum.auto()
EXPIRED = enum.auto()
ACCEPTED = enum.auto()
FAILED = enum.auto()
class ReceivedMPPStatus(NamedTuple):
resolution: RecvMPPResolution
expected_msat: int
htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]]
class ErrorAddingPeer(Exception): pass
@@ -657,8 +671,8 @@ class LNWallet(LNWorker):
self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]]
self.sent_htlcs_info = dict() # (RHASH, scid, htlc_id) -> route, payment_secret, amount_msat, bucket_msat, trampoline_fee_level
self.sent_buckets = dict() # payment_secret -> (amount_sent, amount_failed)
self.received_mpp_htlcs = dict() # RHASH -> mpp_status, htlc_set
self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed)
self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
# detect inflight payments
@@ -1397,13 +1411,14 @@ class LNWallet(LNWorker):
key = (payment_hash, short_channel_id, htlc.htlc_id)
self.sent_htlcs_info[key] = route, payment_secret, amount_msat, total_msat, amount_receiver_msat, trampoline_fee_level, trampoline_route
payment_key = payment_hash + payment_secret
# if we sent MPP to a trampoline, add item to sent_buckets
if self.uses_trampoline() and amount_msat != total_msat:
if payment_secret not in self.sent_buckets:
self.sent_buckets[payment_secret] = (0, 0)
amount_sent, amount_failed = self.sent_buckets[payment_secret]
if payment_key not in self.sent_buckets:
self.sent_buckets[payment_key] = (0, 0)
amount_sent, amount_failed = self.sent_buckets[payment_key]
amount_sent += amount_receiver_msat
self.sent_buckets[payment_secret] = amount_sent, amount_failed
self.sent_buckets[payment_key] = amount_sent, amount_failed
if self.network.path_finder:
# add inflight htlcs to liquidity hints
self.network.path_finder.update_inflight_htlcs(route, add_htlcs=True)
@@ -1879,33 +1894,91 @@ class LNWallet(LNWorker):
if write_to_disk:
self.wallet.save_db()
def check_received_mpp_htlc(self, payment_secret, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]:
""" return MPP status: True (accepted), False (expired) or None """
def check_mpp_status(
self, *,
payment_secret: bytes,
short_channel_id: ShortChannelID,
htlc: UpdateAddHtlc,
expected_msat: int,
) -> RecvMPPResolution:
payment_hash = htlc.payment_hash
is_expired, is_accepted, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, set()))
if self.get_payment_status(payment_hash) == PR_PAID:
# payment_status is persisted
is_accepted = True
is_expired = False
key = (short_channel_id, htlc)
if key not in htlc_set:
htlc_set.add(key)
if not is_accepted and not is_expired:
total = sum([_htlc.amount_msat for scid, _htlc in htlc_set])
first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set])
if self.stopping_soon:
is_expired = True # try to time out pending HTLCs before shutting down
payment_key = payment_hash + payment_secret
self.update_mpp_with_received_htlc(
payment_key=payment_key, scid=short_channel_id, htlc=htlc, expected_msat=expected_msat)
mpp_resolution = self.received_mpp_htlcs[payment_key].resolution
if mpp_resolution == RecvMPPResolution.WAITING:
first_timestamp = self.get_first_timestamp_of_mpp(payment_key)
if self.get_payment_status(payment_hash) == PR_PAID:
mpp_resolution = RecvMPPResolution.ACCEPTED
elif self.stopping_soon:
# try to time out pending HTLCs before shutting down
mpp_resolution = RecvMPPResolution.EXPIRED
elif self.is_mpp_amount_reached(payment_key):
mpp_resolution = RecvMPPResolution.ACCEPTED
elif time.time() - first_timestamp > self.MPP_EXPIRY:
is_expired = True
elif total == expected_msat:
is_accepted = True
if is_accepted or is_expired:
htlc_set.remove(key)
if len(htlc_set) > 0:
self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, htlc_set
elif payment_secret in self.received_mpp_htlcs:
self.received_mpp_htlcs.pop(payment_secret)
return True if is_accepted else (False if is_expired else None)
mpp_resolution = RecvMPPResolution.EXPIRED
if mpp_resolution != RecvMPPResolution.WAITING:
self.set_mpp_resolution(payment_key=payment_key, resolution=mpp_resolution)
self.maybe_cleanup_mpp_status(payment_key, short_channel_id, htlc)
return mpp_resolution
def update_mpp_with_received_htlc(
self,
*,
payment_key: bytes,
scid: ShortChannelID,
htlc: UpdateAddHtlc,
expected_msat: int,
):
# add new htlc to set
mpp_status = self.received_mpp_htlcs.get(payment_key)
if mpp_status is None:
mpp_status = ReceivedMPPStatus(
resolution=RecvMPPResolution.WAITING,
expected_msat=expected_msat,
htlc_set=set(),
)
if expected_msat != mpp_status.expected_msat:
self.logger.info(
f"marking received mpp as failed. inconsistent total_msats in bucket. {payment_key.hex()=}")
mpp_status = mpp_status._replace(resolution=RecvMPPResolution.FAILED)
key = (scid, htlc)
if key not in mpp_status.htlc_set:
mpp_status.htlc_set.add(key) # side-effecting htlc_set
self.received_mpp_htlcs[payment_key] = mpp_status
def set_mpp_resolution(self, *, payment_key: bytes, resolution: RecvMPPResolution):
mpp_status = self.received_mpp_htlcs[payment_key]
self.received_mpp_htlcs[payment_key] = mpp_status._replace(resolution=resolution)
def is_mpp_amount_reached(self, payment_key: bytes) -> bool:
mpp_status = self.received_mpp_htlcs.get(payment_key)
if not mpp_status:
return False
total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set])
return total >= mpp_status.expected_msat
def get_first_timestamp_of_mpp(self, payment_key: bytes) -> int:
mpp_status = self.received_mpp_htlcs.get(payment_key)
if not mpp_status:
return int(time.time())
return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set])
def maybe_cleanup_mpp_status(
self,
payment_key: bytes,
short_channel_id: ShortChannelID,
htlc: UpdateAddHtlc,
) -> None:
mpp_status = self.received_mpp_htlcs[payment_key]
if mpp_status.resolution == RecvMPPResolution.WAITING:
return
key = (short_channel_id, htlc)
mpp_status.htlc_set.remove(key) # side-effecting htlc_set
if not mpp_status.htlc_set and payment_key in self.received_mpp_htlcs:
self.received_mpp_htlcs.pop(payment_key)
def get_payment_status(self, payment_hash: bytes) -> int:
info = self.get_payment_info(payment_hash)
@@ -2012,10 +2085,11 @@ class LNWallet(LNWorker):
self.logger.info(f"htlc_failed {failure_message}")
# check sent_buckets if we use trampoline
if self.uses_trampoline() and payment_secret in self.sent_buckets:
amount_sent, amount_failed = self.sent_buckets[payment_secret]
payment_key = payment_hash + payment_secret
if self.uses_trampoline() and payment_key in self.sent_buckets:
amount_sent, amount_failed = self.sent_buckets[payment_key]
amount_failed += amount_receiver_msat
self.sent_buckets[payment_secret] = amount_sent, amount_failed
self.sent_buckets[payment_key] = amount_sent, amount_failed
if amount_sent != amount_failed:
self.logger.info('bucket still active...')
return

View File

@@ -245,7 +245,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
set_request_status = LNWallet.set_request_status
set_payment_status = LNWallet.set_payment_status
get_payment_status = LNWallet.get_payment_status
check_received_mpp_htlc = LNWallet.check_received_mpp_htlc
check_mpp_status = LNWallet.check_mpp_status
htlc_fulfilled = LNWallet.htlc_fulfilled
htlc_failed = LNWallet.htlc_failed
save_preimage = LNWallet.save_preimage
@@ -273,6 +273,11 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
_on_maybe_forwarded_htlc_resolved = LNWallet._on_maybe_forwarded_htlc_resolved
_force_close_channel = LNWallet._force_close_channel
suggest_splits = LNWallet.suggest_splits
update_mpp_with_received_htlc = LNWallet.update_mpp_with_received_htlc
set_mpp_resolution = LNWallet.set_mpp_resolution
is_mpp_amount_reached = LNWallet.is_mpp_amount_reached
get_first_timestamp_of_mpp = LNWallet.get_first_timestamp_of_mpp
maybe_cleanup_mpp_status = LNWallet.maybe_cleanup_mpp_status
class MockTransport: