lnworker: add RecvMPPResolution with "FAILED" state
- add RecvMPPResolution enum for possible states of a pending incoming MPP, and use it in check_mpp_status - new state: "FAILED", to allow nicely failing back the whole MPP set - key more things with payment_hash+payment_secret, for consistency (just payment_hash is insufficient for trampoline forwarding)
This commit is contained in:
@@ -1826,14 +1826,25 @@ class Peer(Logger):
|
|||||||
log_fail_reason(f"'payment_secret' missing from onion")
|
log_fail_reason(f"'payment_secret' missing from onion")
|
||||||
raise exc_incorrect_or_unknown_pd
|
raise exc_incorrect_or_unknown_pd
|
||||||
|
|
||||||
payment_status = self.lnworker.check_mpp_status(payment_secret_from_onion, chan.short_channel_id, htlc, total_msat)
|
from .lnworker import RecvMPPResolution
|
||||||
if payment_status is None:
|
mpp_resolution = self.lnworker.check_mpp_status(
|
||||||
|
payment_secret=payment_secret_from_onion,
|
||||||
|
short_channel_id=chan.short_channel_id,
|
||||||
|
htlc=htlc,
|
||||||
|
expected_msat=total_msat,
|
||||||
|
)
|
||||||
|
if mpp_resolution == RecvMPPResolution.WAITING:
|
||||||
return None, None
|
return None, None
|
||||||
elif payment_status is False:
|
elif mpp_resolution == RecvMPPResolution.EXPIRED:
|
||||||
log_fail_reason(f"MPP_TIMEOUT")
|
log_fail_reason(f"MPP_TIMEOUT")
|
||||||
raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'')
|
raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'')
|
||||||
|
elif mpp_resolution == RecvMPPResolution.FAILED:
|
||||||
|
log_fail_reason(f"mpp_resolution is FAILED")
|
||||||
|
raise exc_incorrect_or_unknown_pd
|
||||||
|
elif mpp_resolution == RecvMPPResolution.ACCEPTED:
|
||||||
|
pass # continue
|
||||||
else:
|
else:
|
||||||
assert payment_status is True
|
raise Exception(f"unexpected {mpp_resolution=}")
|
||||||
|
|
||||||
payment_hash = htlc.payment_hash
|
payment_hash = htlc.payment_hash
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,8 @@ from decimal import Decimal
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import operator
|
import operator
|
||||||
from enum import IntEnum
|
import enum
|
||||||
|
from enum import IntEnum, Enum
|
||||||
from typing import (Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING,
|
from typing import (Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING,
|
||||||
NamedTuple, Union, Mapping, Any, Iterable, AsyncGenerator, DefaultDict, Callable)
|
NamedTuple, Union, Mapping, Any, Iterable, AsyncGenerator, DefaultDict, Callable)
|
||||||
import threading
|
import threading
|
||||||
@@ -167,9 +168,15 @@ class PaymentInfo(NamedTuple):
|
|||||||
status: int
|
status: int
|
||||||
|
|
||||||
|
|
||||||
|
class RecvMPPResolution(Enum):
|
||||||
|
WAITING = enum.auto()
|
||||||
|
EXPIRED = enum.auto()
|
||||||
|
ACCEPTED = enum.auto()
|
||||||
|
FAILED = enum.auto()
|
||||||
|
|
||||||
|
|
||||||
class ReceivedMPPStatus(NamedTuple):
|
class ReceivedMPPStatus(NamedTuple):
|
||||||
is_expired: bool
|
resolution: RecvMPPResolution
|
||||||
is_accepted: bool
|
|
||||||
expected_msat: int
|
expected_msat: int
|
||||||
htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]]
|
htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]]
|
||||||
|
|
||||||
@@ -673,8 +680,8 @@ class LNWallet(LNWorker):
|
|||||||
|
|
||||||
self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]]
|
self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]]
|
||||||
self.sent_htlcs_info = dict() # (RHASH, scid, htlc_id) -> route, payment_secret, amount_msat, bucket_msat, trampoline_fee_level
|
self.sent_htlcs_info = dict() # (RHASH, scid, htlc_id) -> route, payment_secret, amount_msat, bucket_msat, trampoline_fee_level
|
||||||
self.sent_buckets = dict() # payment_secret -> (amount_sent, amount_failed)
|
self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed)
|
||||||
self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_secret -> ReceivedMPPStatus
|
self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
|
||||||
|
|
||||||
self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
|
self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
|
||||||
# detect inflight payments
|
# detect inflight payments
|
||||||
@@ -1418,13 +1425,14 @@ class LNWallet(LNWorker):
|
|||||||
|
|
||||||
key = (payment_hash, short_channel_id, htlc.htlc_id)
|
key = (payment_hash, short_channel_id, htlc.htlc_id)
|
||||||
self.sent_htlcs_info[key] = route, payment_secret, amount_msat, total_msat, amount_receiver_msat, trampoline_fee_level, trampoline_route
|
self.sent_htlcs_info[key] = route, payment_secret, amount_msat, total_msat, amount_receiver_msat, trampoline_fee_level, trampoline_route
|
||||||
|
payment_key = payment_hash + payment_secret
|
||||||
# if we sent MPP to a trampoline, add item to sent_buckets
|
# if we sent MPP to a trampoline, add item to sent_buckets
|
||||||
if self.uses_trampoline() and amount_msat != total_msat:
|
if self.uses_trampoline() and amount_msat != total_msat:
|
||||||
if payment_secret not in self.sent_buckets:
|
if payment_key not in self.sent_buckets:
|
||||||
self.sent_buckets[payment_secret] = (0, 0)
|
self.sent_buckets[payment_key] = (0, 0)
|
||||||
amount_sent, amount_failed = self.sent_buckets[payment_secret]
|
amount_sent, amount_failed = self.sent_buckets[payment_key]
|
||||||
amount_sent += amount_receiver_msat
|
amount_sent += amount_receiver_msat
|
||||||
self.sent_buckets[payment_secret] = amount_sent, amount_failed
|
self.sent_buckets[payment_key] = amount_sent, amount_failed
|
||||||
if self.network.path_finder:
|
if self.network.path_finder:
|
||||||
# add inflight htlcs to liquidity hints
|
# add inflight htlcs to liquidity hints
|
||||||
self.network.path_finder.update_inflight_htlcs(route, add_htlcs=True)
|
self.network.path_finder.update_inflight_htlcs(route, add_htlcs=True)
|
||||||
@@ -1867,6 +1875,14 @@ class LNWallet(LNWorker):
|
|||||||
def get_payment_secret(self, payment_hash):
|
def get_payment_secret(self, payment_hash):
|
||||||
return sha256(sha256(self.payment_secret_key) + payment_hash)
|
return sha256(sha256(self.payment_secret_key) + payment_hash)
|
||||||
|
|
||||||
|
def _get_payment_key(self, payment_hash: bytes) -> bytes:
|
||||||
|
"""Return payment bucket key.
|
||||||
|
We bucket htlcs based on payment_hash+payment_secret. payment_secret is included
|
||||||
|
as it changes over a trampoline path (in the outer onion), and these paths can overlap.
|
||||||
|
"""
|
||||||
|
payment_secret = self.get_payment_secret(payment_hash)
|
||||||
|
return payment_hash + payment_secret
|
||||||
|
|
||||||
def create_payment_info(self, *, amount_msat: Optional[int], write_to_disk=True) -> bytes:
|
def create_payment_info(self, *, amount_msat: Optional[int], write_to_disk=True) -> bytes:
|
||||||
payment_preimage = os.urandom(32)
|
payment_preimage = os.urandom(32)
|
||||||
payment_hash = sha256(payment_preimage)
|
payment_hash = sha256(payment_preimage)
|
||||||
@@ -1923,103 +1939,101 @@ class LNWallet(LNWorker):
|
|||||||
self.wallet.save_db()
|
self.wallet.save_db()
|
||||||
|
|
||||||
def check_mpp_status(
|
def check_mpp_status(
|
||||||
self, payment_secret: bytes,
|
self, *,
|
||||||
|
payment_secret: bytes,
|
||||||
short_channel_id: ShortChannelID,
|
short_channel_id: ShortChannelID,
|
||||||
htlc: UpdateAddHtlc,
|
htlc: UpdateAddHtlc,
|
||||||
expected_msat: int,
|
expected_msat: int,
|
||||||
) -> Optional[bool]:
|
) -> RecvMPPResolution:
|
||||||
""" return MPP status: True (accepted), False (expired) or None (waiting)
|
|
||||||
"""
|
|
||||||
payment_hash = htlc.payment_hash
|
payment_hash = htlc.payment_hash
|
||||||
self.update_mpp_with_received_htlc(payment_secret, short_channel_id, htlc, expected_msat)
|
payment_key = payment_hash + payment_secret
|
||||||
is_expired, is_accepted = self.get_mpp_status(payment_secret)
|
self.update_mpp_with_received_htlc(
|
||||||
if not is_accepted and not is_expired:
|
payment_key=payment_key, scid=short_channel_id, htlc=htlc, expected_msat=expected_msat)
|
||||||
|
mpp_resolution = self.received_mpp_htlcs[payment_key].resolution
|
||||||
|
if mpp_resolution == RecvMPPResolution.WAITING:
|
||||||
bundle = self.get_payment_bundle(payment_hash)
|
bundle = self.get_payment_bundle(payment_hash)
|
||||||
if bundle:
|
if bundle:
|
||||||
payment_secrets = [self.get_payment_secret(h) for h in bundle]
|
payment_keys = [self._get_payment_key(h) for h in bundle]
|
||||||
if payment_secret not in payment_secrets:
|
if payment_key not in payment_keys:
|
||||||
# outer trampoline onion secret differs from inner onion
|
# outer trampoline onion secret differs from inner onion
|
||||||
# the latter, not the former, might be part of a bundle
|
# the latter, not the former, might be part of a bundle
|
||||||
payment_secrets = [payment_secret]
|
payment_keys = [payment_key]
|
||||||
else:
|
else:
|
||||||
payment_secrets = [payment_secret]
|
payment_keys = [payment_key]
|
||||||
first_timestamp = min([self.get_first_timestamp_of_mpp(x) for x in payment_secrets])
|
first_timestamp = min([self.get_first_timestamp_of_mpp(pkey) for pkey in payment_keys])
|
||||||
if self.get_payment_status(payment_hash) == PR_PAID:
|
if self.get_payment_status(payment_hash) == PR_PAID:
|
||||||
is_accepted = True
|
mpp_resolution = RecvMPPResolution.ACCEPTED
|
||||||
elif self.stopping_soon:
|
elif self.stopping_soon:
|
||||||
is_expired = True # try to time out pending HTLCs before shutting down
|
# try to time out pending HTLCs before shutting down
|
||||||
elif all([self.is_mpp_amount_reached(x) for x in payment_secrets]):
|
mpp_resolution = RecvMPPResolution.EXPIRED
|
||||||
is_accepted = True
|
elif all([self.is_mpp_amount_reached(pkey) for pkey in payment_keys]):
|
||||||
|
mpp_resolution = RecvMPPResolution.ACCEPTED
|
||||||
elif time.time() - first_timestamp > self.MPP_EXPIRY:
|
elif time.time() - first_timestamp > self.MPP_EXPIRY:
|
||||||
is_expired = True
|
mpp_resolution = RecvMPPResolution.EXPIRED
|
||||||
|
|
||||||
if is_accepted or is_expired:
|
if mpp_resolution != RecvMPPResolution.WAITING:
|
||||||
for x in payment_secrets:
|
for pkey in payment_keys:
|
||||||
if x in self.received_mpp_htlcs:
|
if pkey in self.received_mpp_htlcs:
|
||||||
self.set_mpp_status(x, is_expired, is_accepted)
|
self.set_mpp_resolution(payment_key=pkey, resolution=mpp_resolution)
|
||||||
|
|
||||||
self.maybe_cleanup_mpp_status(payment_secret, short_channel_id, htlc)
|
self.maybe_cleanup_mpp_status(payment_key, short_channel_id, htlc)
|
||||||
return True if is_accepted else (False if is_expired else None)
|
return mpp_resolution
|
||||||
|
|
||||||
def update_mpp_with_received_htlc(
|
def update_mpp_with_received_htlc(
|
||||||
self,
|
self,
|
||||||
payment_secret: bytes,
|
*,
|
||||||
short_channel_id: ShortChannelID,
|
payment_key: bytes,
|
||||||
|
scid: ShortChannelID,
|
||||||
htlc: UpdateAddHtlc,
|
htlc: UpdateAddHtlc,
|
||||||
expected_msat: int,
|
expected_msat: int,
|
||||||
):
|
):
|
||||||
# add new htlc to set
|
# add new htlc to set
|
||||||
mpp_status = self.received_mpp_htlcs.get(payment_secret)
|
mpp_status = self.received_mpp_htlcs.get(payment_key)
|
||||||
if mpp_status is None:
|
if mpp_status is None:
|
||||||
mpp_status = ReceivedMPPStatus(
|
mpp_status = ReceivedMPPStatus(
|
||||||
is_expired=False,
|
resolution=RecvMPPResolution.WAITING,
|
||||||
is_accepted=False,
|
|
||||||
expected_msat=expected_msat,
|
expected_msat=expected_msat,
|
||||||
htlc_set=set(),
|
htlc_set=set(),
|
||||||
)
|
)
|
||||||
assert expected_msat == mpp_status.expected_msat
|
if expected_msat != mpp_status.expected_msat:
|
||||||
key = (short_channel_id, htlc)
|
self.logger.info(
|
||||||
|
f"marking received mpp as failed. inconsistent total_msats in bucket. {payment_key.hex()=}")
|
||||||
|
mpp_status = mpp_status._replace(resolution=RecvMPPResolution.FAILED)
|
||||||
|
key = (scid, htlc)
|
||||||
if key not in mpp_status.htlc_set:
|
if key not in mpp_status.htlc_set:
|
||||||
mpp_status.htlc_set.add(key) # side-effecting htlc_set
|
mpp_status.htlc_set.add(key) # side-effecting htlc_set
|
||||||
self.received_mpp_htlcs[payment_secret] = mpp_status
|
self.received_mpp_htlcs[payment_key] = mpp_status
|
||||||
|
|
||||||
def get_mpp_status(self, payment_secret: bytes) -> Tuple[bool, bool]:
|
def set_mpp_resolution(self, *, payment_key: bytes, resolution: RecvMPPResolution):
|
||||||
mpp_status = self.received_mpp_htlcs[payment_secret]
|
mpp_status = self.received_mpp_htlcs[payment_key]
|
||||||
return mpp_status.is_expired, mpp_status.is_accepted
|
self.received_mpp_htlcs[payment_key] = mpp_status._replace(resolution=resolution)
|
||||||
|
|
||||||
def set_mpp_status(self, payment_secret: bytes, is_expired: bool, is_accepted: bool):
|
def is_mpp_amount_reached(self, payment_key: bytes) -> bool:
|
||||||
mpp_status = self.received_mpp_htlcs[payment_secret]
|
mpp_status = self.received_mpp_htlcs.get(payment_key)
|
||||||
self.received_mpp_htlcs[payment_secret] = mpp_status._replace(
|
|
||||||
is_expired=is_expired,
|
|
||||||
is_accepted=is_accepted,
|
|
||||||
)
|
|
||||||
|
|
||||||
def is_mpp_amount_reached(self, payment_secret: bytes) -> bool:
|
|
||||||
mpp_status = self.received_mpp_htlcs.get(payment_secret)
|
|
||||||
if not mpp_status:
|
if not mpp_status:
|
||||||
return False
|
return False
|
||||||
total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set])
|
total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set])
|
||||||
return total >= mpp_status.expected_msat
|
return total >= mpp_status.expected_msat
|
||||||
|
|
||||||
def get_first_timestamp_of_mpp(self, payment_secret: bytes) -> int:
|
def get_first_timestamp_of_mpp(self, payment_key: bytes) -> int:
|
||||||
mpp_status = self.received_mpp_htlcs.get(payment_secret)
|
mpp_status = self.received_mpp_htlcs.get(payment_key)
|
||||||
if not mpp_status:
|
if not mpp_status:
|
||||||
return int(time.time())
|
return int(time.time())
|
||||||
return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set])
|
return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set])
|
||||||
|
|
||||||
def maybe_cleanup_mpp_status(
|
def maybe_cleanup_mpp_status(
|
||||||
self,
|
self,
|
||||||
payment_secret: bytes,
|
payment_key: bytes,
|
||||||
short_channel_id: ShortChannelID,
|
short_channel_id: ShortChannelID,
|
||||||
htlc: UpdateAddHtlc,
|
htlc: UpdateAddHtlc,
|
||||||
) -> None:
|
) -> None:
|
||||||
mpp_status = self.received_mpp_htlcs[payment_secret]
|
mpp_status = self.received_mpp_htlcs[payment_key]
|
||||||
if not mpp_status.is_accepted and not mpp_status.is_expired:
|
if mpp_status.resolution == RecvMPPResolution.WAITING:
|
||||||
return
|
return
|
||||||
key = (short_channel_id, htlc)
|
key = (short_channel_id, htlc)
|
||||||
mpp_status.htlc_set.remove(key) # side-effecting htlc_set
|
mpp_status.htlc_set.remove(key) # side-effecting htlc_set
|
||||||
if not mpp_status.htlc_set and payment_secret in self.received_mpp_htlcs:
|
if not mpp_status.htlc_set and payment_key in self.received_mpp_htlcs:
|
||||||
self.received_mpp_htlcs.pop(payment_secret)
|
self.received_mpp_htlcs.pop(payment_key)
|
||||||
|
|
||||||
def get_payment_status(self, payment_hash: bytes) -> int:
|
def get_payment_status(self, payment_hash: bytes) -> int:
|
||||||
info = self.get_payment_info(payment_hash)
|
info = self.get_payment_info(payment_hash)
|
||||||
@@ -2126,10 +2140,11 @@ class LNWallet(LNWorker):
|
|||||||
self.logger.info(f"htlc_failed {failure_message}")
|
self.logger.info(f"htlc_failed {failure_message}")
|
||||||
|
|
||||||
# check sent_buckets if we use trampoline
|
# check sent_buckets if we use trampoline
|
||||||
if self.uses_trampoline() and payment_secret in self.sent_buckets:
|
payment_key = payment_hash + payment_secret
|
||||||
amount_sent, amount_failed = self.sent_buckets[payment_secret]
|
if self.uses_trampoline() and payment_key in self.sent_buckets:
|
||||||
|
amount_sent, amount_failed = self.sent_buckets[payment_key]
|
||||||
amount_failed += amount_receiver_msat
|
amount_failed += amount_receiver_msat
|
||||||
self.sent_buckets[payment_secret] = amount_sent, amount_failed
|
self.sent_buckets[payment_key] = amount_sent, amount_failed
|
||||||
if amount_sent != amount_failed:
|
if amount_sent != amount_failed:
|
||||||
self.logger.info('bucket still active...')
|
self.logger.info('bucket still active...')
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -283,13 +283,13 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
|||||||
add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice
|
add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice
|
||||||
|
|
||||||
update_mpp_with_received_htlc = LNWallet.update_mpp_with_received_htlc
|
update_mpp_with_received_htlc = LNWallet.update_mpp_with_received_htlc
|
||||||
get_mpp_status = LNWallet.get_mpp_status
|
set_mpp_resolution = LNWallet.set_mpp_resolution
|
||||||
set_mpp_status = LNWallet.set_mpp_status
|
|
||||||
is_mpp_amount_reached = LNWallet.is_mpp_amount_reached
|
is_mpp_amount_reached = LNWallet.is_mpp_amount_reached
|
||||||
get_first_timestamp_of_mpp = LNWallet.get_first_timestamp_of_mpp
|
get_first_timestamp_of_mpp = LNWallet.get_first_timestamp_of_mpp
|
||||||
maybe_cleanup_mpp_status = LNWallet.maybe_cleanup_mpp_status
|
maybe_cleanup_mpp_status = LNWallet.maybe_cleanup_mpp_status
|
||||||
bundle_payments = LNWallet.bundle_payments
|
bundle_payments = LNWallet.bundle_payments
|
||||||
get_payment_bundle = LNWallet.get_payment_bundle
|
get_payment_bundle = LNWallet.get_payment_bundle
|
||||||
|
_get_payment_key = LNWallet._get_payment_key
|
||||||
|
|
||||||
|
|
||||||
class MockTransport:
|
class MockTransport:
|
||||||
|
|||||||
Reference in New Issue
Block a user