From 0f314d1dd965424d898f4d8c564449ee42c21e97 Mon Sep 17 00:00:00 2001 From: f321x Date: Tue, 9 Sep 2025 13:10:17 +0200 Subject: [PATCH] lnpeer/lnworker: refactor htlc_switch refactor `htlc_switch` to new architecture to make it more robust against partial settlement of htlc sets and increase maintainability. Htlcs are now processed in two steps, first the htlcs are collected into sets from the channels, and potentially failed on their own already. Then a second loop iterates over the htlc sets and finalizes only on whole sets. # Conflicts: # electrum/lnpeer.py --- electrum/commands.py | 2 +- electrum/lnchannel.py | 7 +- electrum/lnonion.py | 60 ++- electrum/lnpeer.py | 860 ++++++++++++++++++++++++------------ electrum/lnsweep.py | 4 + electrum/lnutil.py | 80 +++- electrum/lnworker.py | 327 +++++++++----- electrum/submarine_swaps.py | 15 +- electrum/wallet_db.py | 107 ++++- tests/test_commands.py | 10 +- tests/test_lnpeer.py | 14 +- 11 files changed, 1038 insertions(+), 448 deletions(-) diff --git a/electrum/commands.py b/electrum/commands.py index 94d0d192b..f249f55f9 100644 --- a/electrum/commands.py +++ b/electrum/commands.py @@ -1508,7 +1508,7 @@ class Commands(Logger): payment_key: str = wallet.lnworker._get_payment_key(bfh(payment_hash)).hex() htlc_status = wallet.lnworker.received_mpp_htlcs[payment_key] result["closest_htlc_expiry_height"] = min( - htlc.cltv_abs for _, htlc in htlc_status.htlc_set + mpp_htlc.htlc.cltv_abs for mpp_htlc in htlc_status.htlcs ) elif wallet.lnworker.get_preimage_hex(payment_hash) is not None \ and payment_hash not in wallet.lnworker.dont_settle_htlcs: diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index dd5fedaca..36c2dad7f 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -783,8 +783,8 @@ class Channel(AbstractChannel): self.onion_keys = state['onion_keys'] # type: Dict[int, bytes] self.data_loss_protect_remote_pcp = state['data_loss_protect_remote_pcp'] self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate) - self.unfulfilled_htlcs = state["unfulfilled_htlcs"] # type: Dict[int, Tuple[str, Optional[str]]] - # ^ htlc_id -> onion_packet_hex, forwarding_key + self.unfulfilled_htlcs = state["unfulfilled_htlcs"] # type: Dict[int, Optional[str]] + # ^ htlc_id -> onion_packet_hex self._state = ChannelState[state['state']] self.peer_state = PeerState.DISCONNECTED self._outgoing_channel_update = None # type: Optional[bytes] @@ -1112,6 +1112,7 @@ class Channel(AbstractChannel): if amount_msat <= 0: raise PaymentFailure("HTLC value must be positive") if amount_msat < chan_config.htlc_minimum_msat: + # todo: for incoming htlcs this could be handled more gracefully with `amount_below_minimum` raise PaymentFailure(f'HTLC value too small: {amount_msat} msat') if self.htlc_slots_left(htlc_proposer) == 0: @@ -1226,7 +1227,7 @@ class Channel(AbstractChannel): with self.db_lock: self.hm.recv_htlc(htlc) if onion_packet: - self.unfulfilled_htlcs[htlc.htlc_id] = onion_packet.hex(), None + self.unfulfilled_htlcs[htlc.htlc_id] = onion_packet.hex() self.logger.info("receive_htlc") return htlc diff --git a/electrum/lnonion.py b/electrum/lnonion.py index f624627e1..b57c04792 100644 --- a/electrum/lnonion.py +++ b/electrum/lnonion.py @@ -26,7 +26,8 @@ import io import hashlib from functools import cached_property -from typing import Sequence, List, Tuple, NamedTuple, TYPE_CHECKING, Dict, Any, Optional, Union, Mapping +from typing import (Sequence, List, Tuple, NamedTuple, TYPE_CHECKING, Dict, Any, Optional, Union, + Mapping, Iterator) from enum import IntEnum from dataclasses import dataclass, field, replace from types import MappingProxyType @@ -485,6 +486,55 @@ def process_onion_packet( return ProcessedOnionPacket(are_we_final, hop_data, next_onion_packet, trampoline_onion_packet) +def compare_trampoline_onions( + trampoline_onions: Iterator[Optional[ProcessedOnionPacket]], + *, + exclude_amt_to_fwd: bool = False, +) -> bool: + """ + compare values of trampoline onions payloads and are_we_final. + If we are receiver of a multi trampoline payment amt_to_fwd can differ between the trampoline + parts of the payment, so it needs to be excluded from the comparison when comparing all trampoline + onions of the whole payment (however it can be compared between the onions in a single trampoline part). + """ + try: + first_onion = next(trampoline_onions) + except StopIteration: + raise ValueError("nothing to compare") + + if first_onion is None: + # we don't support mixed mpp sets of htlcs with trampoline onions and regular non-trampoline htlcs. + # In theory this could happen if a sender e.g. uses trampoline as fallback to deliver + # outstanding mpp parts if local pathfinding wasn't successful for the whole payment, + # resulting in a mixed payment. However, it's not even clear if the spec allows for such a constellation. + return all(onion is None for onion in trampoline_onions) + assert isinstance(first_onion, ProcessedOnionPacket), f"{first_onion=}" + + are_we_final = first_onion.are_we_final + payload = first_onion.hop_data.payload + total_msat = first_onion.total_msat + outgoing_cltv = first_onion.outgoing_cltv_value + payment_secret = first_onion.payment_secret + for onion in trampoline_onions: + if onion is None: + return False + assert isinstance(onion, ProcessedOnionPacket), f"{onion=}" + assert onion.trampoline_onion_packet is None, f"{onion=} cannot have trampoline_onion_packet" + if onion.are_we_final != are_we_final: + return False + if not exclude_amt_to_fwd: + if onion.hop_data.payload != payload: + return False + else: + if onion.total_msat != total_msat: + return False + if onion.outgoing_cltv_value != outgoing_cltv: + return False + if onion.payment_secret != payment_secret: + return False + return True + + class FailedToDecodeOnionError(Exception): pass @@ -521,6 +571,14 @@ class OnionRoutingFailure(Exception): payload = None return payload + def to_wire_msg(self, onion_packet: OnionPacket, privkey: bytes, local_height: int) -> bytes: + onion_error = construct_onion_error(self, onion_packet.public_key, privkey, local_height) + error_bytes = obfuscate_onion_error(onion_error, onion_packet.public_key, privkey) + return error_bytes + + +class OnionParsingError(OnionRoutingFailure): pass + def construct_onion_error( error: OnionRoutingFailure, diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 72db12657..377dd375f 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -21,6 +21,7 @@ from electrum_ecc import ecdsa_sig64_from_r_and_s, ecdsa_der_sig_from_ecdsa_sig6 import aiorpcx from aiorpcx import ignore_after +from .lrucache import LRUCache from .crypto import sha256, sha256d, privkey_to_pubkey from . import bitcoin, util from . import constants @@ -31,11 +32,11 @@ from . import transaction from .bitcoin import make_op_return, DummyAddress from .transaction import PartialTxOutput, match_script_against_template, Sighash from .logging import Logger -from .lnrouter import RouteEdge -from .lnonion import (new_onion_packet, OnionFailureCode, calc_hops_data_for_payment, process_onion_packet, - OnionPacket, construct_onion_error, obfuscate_onion_error, OnionRoutingFailure, - ProcessedOnionPacket, UnsupportedOnionPacketVersion, InvalidOnionMac, InvalidOnionPubkey, - OnionFailureCodeMetaFlag) +from . import lnonion +from .lnonion import (OnionFailureCode, OnionPacket, obfuscate_onion_error, + OnionRoutingFailure, ProcessedOnionPacket, UnsupportedOnionPacketVersion, + InvalidOnionMac, InvalidOnionPubkey, OnionFailureCodeMetaFlag, + OnionParsingError) from .lnchannel import Channel, RevokeAndAck, ChannelState, PeerState, ChanCloseOption, CF_ANNOUNCE_CHANNEL from . import lnutil from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, ChannelConfig, @@ -46,17 +47,15 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, ChannelConf ln_compare_features, MIN_FINAL_CLTV_DELTA_ACCEPTED, RemoteMisbehaving, ShortChannelID, IncompatibleLightningFeatures, ChannelType, LNProtocolWarning, validate_features, - IncompatibleOrInsaneFeatures, FeeBudgetExceeded, + IncompatibleOrInsaneFeatures, ReceivedMPPStatus, ReceivedMPPHtlc, GossipForwardingMessage, GossipTimestampFilter, channel_id_from_funding_tx, - PaymentFeeBudget, serialize_htlc_key, Keypair, RecvMPPResolution) + serialize_htlc_key, Keypair, RecvMPPResolution) from .lntransport import LNTransport, LNTransportBase, LightningPeerConnectionClosed, HandshakeFailed from .lnmsg import encode_msg, decode_msg, UnknownOptionalMsgType, FailedToParseMsg from .interface import GracefulDisconnect -from .lnrouter import fee_for_edge_msat from .json_db import StoredDict from .invoices import PR_PAID from .fee_policy import FEE_LN_ETA_TARGET, FEERATE_PER_KW_MIN_RELAY_LIGHTNING -from .trampoline import decode_routing_info if TYPE_CHECKING: from .lnworker import LNGossip, LNWallet @@ -132,6 +131,7 @@ class Peer(Logger, EventListener): self.downstream_htlc_resolved_event = asyncio.Event() self.register_callbacks() self._num_gossip_messages_forwarded = 0 + self._processed_onion_cache = LRUCache(maxsize=100) # type: LRUCache[bytes, ProcessedOnionPacket] def send_message(self, message_name: str, **kwargs): assert util.get_running_loop() == util.get_asyncio_loop(), f"this must be run on the asyncio thread!" @@ -2136,177 +2136,151 @@ class Peer(Logger, EventListener): return payment_secret_from_onion, total_msat, channel_opening_fee, exc_incorrect_or_unknown_pd - def check_mpp_is_waiting( - self, - *, - payment_secret: bytes, - short_channel_id: ShortChannelID, + def _check_unfulfilled_htlc( + self, *, + chan: Channel, htlc: UpdateAddHtlc, - expected_msat: int, - exc_incorrect_or_unknown_pd: OnionRoutingFailure, - log_fail_reason: Callable[[str], None], - ) -> bool: - mpp_resolution = self.lnworker.check_mpp_status( - payment_secret=payment_secret, - short_channel_id=short_channel_id, - htlc=htlc, - expected_msat=expected_msat, - ) - if mpp_resolution == RecvMPPResolution.WAITING: - return True - elif mpp_resolution == RecvMPPResolution.EXPIRED: - log_fail_reason(f"MPP_TIMEOUT") - raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'') - elif mpp_resolution == RecvMPPResolution.FAILED: - log_fail_reason(f"mpp_resolution is FAILED") - raise exc_incorrect_or_unknown_pd - elif mpp_resolution == RecvMPPResolution.COMPLETE: - return False - else: - raise Exception(f"unexpected {mpp_resolution=}") - - def maybe_fulfill_htlc( - self, *, - chan: Channel, - htlc: UpdateAddHtlc, - processed_onion: ProcessedOnionPacket, - outer_onion_payment_secret: bytes = None, # used to group trampoline htlcs for forwarding - onion_packet_bytes: bytes, - already_forwarded: bool = False, - ) -> Tuple[Optional[bytes], Optional[Tuple[str, Callable[[], Awaitable[Optional[str]]]]]]: + processed_onion: ProcessedOnionPacket, + outer_onion_payment_secret: bytes = None, # used to group trampoline htlcs for forwarding + ) -> str: """ - Decide what to do with an HTLC: return preimage if it can be fulfilled, forwarding callback if it can be forwarded. - Return (preimage, (payment_key, callback)) with at most a single element not None. + Does additional checks on the incoming htlc and return the payment key if the tests pass, + otherwise raises OnionRoutingError which will get the htlc failed. """ - if not processed_onion.are_we_final: - if not self.lnworker.enable_htlc_forwarding: - return None, None - # use the htlc key if we are forwarding - payment_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc.htlc_id) - callback = lambda: self.lnworker.maybe_forward_htlc( - incoming_chan=chan, - htlc=htlc, - processed_onion=processed_onion) - return None, (payment_key, callback) + _log_fail_reason = self._log_htlc_fail_reason_cb(chan.short_channel_id, htlc, processed_onion.hop_data.payload) - def log_fail_reason(reason: str): - self.logger.info( - f"maybe_fulfill_htlc. will FAIL HTLC: chan {chan.short_channel_id}. " - f"{reason}. htlc={str(htlc)}. onion_payload={processed_onion.hop_data.payload}") - - chain = self.network.blockchain() # Check that our blockchain tip is sufficiently recent so that we have an approx idea of the height. # We should not release the preimage for an HTLC that its sender could already time out as # then they might try to force-close and it becomes a race. - if chain.is_tip_stale() and not already_forwarded: - log_fail_reason(f"our chain tip is stale") - raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') + chain = self.network.blockchain() local_height = chain.height() + blocks_to_expiry = max(htlc.cltv_abs - local_height, 0) + if chain.is_tip_stale(): + _log_fail_reason(f"our chain tip is stale: {local_height=}") + raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') + + payment_hash = htlc.payment_hash + if not processed_onion.are_we_final: + if outer_onion_payment_secret: + # this is a trampoline forwarding htlc, multiple incoming trampoline htlcs can be collected + payment_key = (payment_hash + outer_onion_payment_secret).hex() + return payment_key + # this is a regular htlc to forward, it will get its own set of size 1 keyed by htlc_key + # Additional checks required only for forwarding nodes will be done in maybe_forward_htlc(). + payment_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc.htlc_id) + return payment_key # parse parameters and perform checks that are invariant - payment_secret_from_onion, total_msat, channel_opening_fee, exc_incorrect_or_unknown_pd = self._check_accepted_final_htlc( - chan=chan, - htlc=htlc, - processed_onion=processed_onion, - is_trampoline_onion=bool(outer_onion_payment_secret), - log_fail_reason=log_fail_reason) - - # payment key for final onions - payment_hash = htlc.payment_hash - payment_key = (payment_hash + payment_secret_from_onion).hex() - - if self.check_mpp_is_waiting( - payment_secret=payment_secret_from_onion, - short_channel_id=chan.get_scid_or_local_alias(), + payment_secret_from_onion, total_msat, channel_opening_fee, exc_incorrect_or_unknown_pd = ( + self._check_accepted_final_htlc( + chan=chan, htlc=htlc, - expected_msat=total_msat, - exc_incorrect_or_unknown_pd=exc_incorrect_or_unknown_pd, - log_fail_reason=log_fail_reason, - ): - return None, None + processed_onion=processed_onion, + is_trampoline_onion=bool(outer_onion_payment_secret), + log_fail_reason=_log_fail_reason, + )) + # trampoline htlcs of which we are the final receiver will first get grouped by the outer + # onions secret to allow grouping a multi-trampoline mpp in different sets. Once a trampoline + # payment part is completed (sum(htlcs) >= (trampoline-)amt_to_forward), its htlcs get moved into + # the htlc set representing the whole payment (payment key derived from trampoline/invoice secret). + payment_key = (payment_hash + (outer_onion_payment_secret or payment_secret_from_onion)).hex() - # TODO check against actual min_final_cltv_expiry_delta from invoice (and give 2-3 blocks of leeway?) - # note: payment_bundles might get split here, e.g. one payment is "already forwarded" and the other is not. - # In practice, for the swap prepayment use case, this does not matter. - if local_height + MIN_FINAL_CLTV_DELTA_ACCEPTED > htlc.cltv_abs and not already_forwarded: - log_fail_reason(f"htlc.cltv_abs is unreasonably close") + if blocks_to_expiry < MIN_FINAL_CLTV_DELTA_ACCEPTED: + # this check should be done here for new htlcs and ongoing on pending sets. + # Here it is done so that invalid received htlcs will never get added to a set, + # so the set still has a chance to succeed until mpp timeout. + _log_fail_reason(f"htlc.cltv_abs is unreasonably close: {htlc.cltv_abs=}, {local_height=}") raise exc_incorrect_or_unknown_pd - # detect callback - # if there is a trampoline_onion, maybe_fulfill_htlc will be called again - # order is important: if we receive a trampoline onion for a hold invoice, we need to peel the onion first. - + # extract trampoline if processed_onion.trampoline_onion_packet: - # TODO: we should check that all trampoline_onions are the same - trampoline_onion = self.process_onion_packet( + trampoline_onion = self._process_incoming_onion_packet( processed_onion.trampoline_onion_packet, payment_hash=payment_hash, - onion_packet_bytes=onion_packet_bytes, is_trampoline=True) - if trampoline_onion.are_we_final: - # trampoline- we are final recipient of HTLC - # note: the returned payment_key will contain the inner payment_secret - return self.maybe_fulfill_htlc( - chan=chan, - htlc=htlc, - processed_onion=trampoline_onion, - outer_onion_payment_secret=payment_secret_from_onion, - onion_packet_bytes=onion_packet_bytes, - already_forwarded=already_forwarded, - ) - else: - callback = lambda: self.lnworker.maybe_forward_trampoline( - payment_hash=payment_hash, - inc_cltv_abs=htlc.cltv_abs, # TODO: use max or enforce same value across mpp parts - outer_onion=processed_onion, - trampoline_onion=trampoline_onion, - fw_payment_key=payment_key) - return None, (payment_key, callback) - # TODO don't accept payments twice for same invoice - # note: we don't check invoice expiry (bolt11 'x' field) on the receiver-side. - # - semantics are weird: would make sense for simple-payment-receives, but not - # if htlc is expected to be pending for a while, e.g. for a hold-invoice. + # compare trampoline onion against outer onion according to: + # https://github.com/lightning/bolts/blob/9938ab3d6160a3ba91f3b0e132858ab14bfe4f81/04-onion-routing.md?plain=1#L547-L553 + if trampoline_onion.are_we_final: + try: + assert not processed_onion.outgoing_cltv_value < trampoline_onion.outgoing_cltv_value + is_mpp = processed_onion.total_msat > processed_onion.amt_to_forward + if is_mpp: + assert not processed_onion.total_msat < trampoline_onion.amt_to_forward + else: + assert not processed_onion.amt_to_forward < trampoline_onion.amt_to_forward + except AssertionError: + _log_fail_reason(f'incorrect trampoline onion {processed_onion=}\n{trampoline_onion=}') + raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00') + + return self._check_unfulfilled_htlc( + chan=chan, + htlc=htlc, + processed_onion=trampoline_onion, + outer_onion_payment_secret=payment_secret_from_onion, + ) + info = self.lnworker.get_payment_info(payment_hash) if info is None: - log_fail_reason(f"no payment_info found for RHASH {htlc.payment_hash.hex()}") + _log_fail_reason(f"no payment_info found for RHASH {payment_hash.hex()}") + raise exc_incorrect_or_unknown_pd + elif info.status == PR_PAID: + _log_fail_reason(f"invoice already paid: {payment_hash.hex()=}") + raise exc_incorrect_or_unknown_pd + elif blocks_to_expiry < info.min_final_cltv_delta: + _log_fail_reason( + f"min final cltv delta lower than requested: " + f"{payment_hash.hex()=} {htlc.cltv_abs=} {blocks_to_expiry=}" + ) + raise exc_incorrect_or_unknown_pd + elif htlc.timestamp > info.expiration_ts: # the set will get failed too if now > exp_ts + _log_fail_reason(f"not accepting htlc for expired invoice") raise exc_incorrect_or_unknown_pd - preimage = self.lnworker.get_preimage(payment_hash) - expected_payment_secret = self.lnworker.get_payment_secret(htlc.payment_hash) - if payment_secret_from_onion != expected_payment_secret: - log_fail_reason(f'incorrect payment secret {payment_secret_from_onion.hex()} != {expected_payment_secret.hex()}') + expected_payment_secret = self.lnworker.get_payment_secret(payment_hash) + if not util.constant_time_compare(payment_secret_from_onion, expected_payment_secret): + _log_fail_reason(f'incorrect payment secret: {payment_secret_from_onion.hex()=}') raise exc_incorrect_or_unknown_pd + invoice_msat = info.amount_msat if channel_opening_fee: + # deduct just-in-time channel fees from invoice amount invoice_msat -= channel_opening_fee if not (invoice_msat is None or invoice_msat <= total_msat <= 2 * invoice_msat): - log_fail_reason(f"total_msat={total_msat} too different from invoice_msat={invoice_msat}") + _log_fail_reason(f"{total_msat=} too different from {invoice_msat=}") raise exc_incorrect_or_unknown_pd - hold_invoice_callback = self.lnworker.hold_invoice_callbacks.get(payment_hash) - if hold_invoice_callback and not preimage: - callback = lambda: hold_invoice_callback(payment_hash) - return None, (payment_key, callback) + return payment_key - if payment_hash.hex() in self.lnworker.dont_settle_htlcs: - return None, None + def _fulfill_htlc_set(self, payment_key: str, preimage: bytes): + htlc_set = self.lnworker.received_mpp_htlcs[payment_key] + assert len(htlc_set.htlcs) > 0, f"{htlc_set=}" + assert htlc_set.resolution == RecvMPPResolution.SETTLING + assert htlc_set.parent_set_key is None, f"Must not settle child {htlc_set=}" + # get payment hash of any htlc in the set (they are all the same) + payment_hash = htlc_set.get_payment_hash() + assert payment_hash is not None, htlc_set + for mpp_htlc in list(htlc_set.htlcs): + htlc_id = mpp_htlc.htlc.htlc_id + chan = self.lnworker.get_channel_by_short_id(mpp_htlc.scid) + if chan.channel_id not in self.channels: + # this htlc belongs to another peer and has to be settled in their htlc_switch + continue + if not chan.can_update_ctx(proposer=LOCAL): + continue + self.logger.info(f"fulfill htlc: {chan.short_channel_id}. {htlc_id=}. {payment_hash.hex()=}") + if chan.hm.was_htlc_preimage_released(htlc_id=htlc_id, htlc_proposer=REMOTE): + # this check is intended to gracefully handle stale htlcs in the set, e.g. after a crash + self.logger.debug(f"{mpp_htlc=} was already settled before, dropping it.") + htlc_set.htlcs.remove(mpp_htlc) + continue + self._fulfill_htlc(chan, htlc_id, preimage) + htlc_set.htlcs.remove(mpp_htlc) + # reset just-in-time opening fee of channel + chan.opening_fee = None - if not preimage: - if not already_forwarded: - log_fail_reason(f"missing preimage and no hold invoice callback {payment_hash.hex()}") - raise exc_incorrect_or_unknown_pd - else: - return None, None - - chan.opening_fee = None - self.logger.info(f"maybe_fulfill_htlc. will FULFILL HTLC: chan {chan.short_channel_id}. htlc={str(htlc)}") - return preimage, None - - def fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes): - self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") - assert chan.can_update_ctx(proposer=LOCAL), f"cannot send updates: {chan.short_channel_id}" + def _fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes): assert chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id) self.received_htlcs_pending_removal.add((chan, htlc_id)) chan.settle_htlc(preimage, htlc_id) @@ -2316,6 +2290,61 @@ class Peer(Logger, EventListener): id=htlc_id, payment_preimage=preimage) + def _fail_htlc_set( + self, + payment_key: str, + error_tuple: Tuple[Optional[bytes], Optional[OnionFailureCode | int], Optional[bytes]], + ): + htlc_set = self.lnworker.received_mpp_htlcs[payment_key] + assert htlc_set.resolution in (RecvMPPResolution.FAILED, RecvMPPResolution.EXPIRED) + + raw_error, error_code, error_data = error_tuple + local_height = self.network.blockchain().height() + for mpp_htlc in list(htlc_set.htlcs): + chan = self.lnworker.get_channel_by_short_id(mpp_htlc.scid) + htlc_id = mpp_htlc.htlc.htlc_id + if chan.channel_id not in self.channels: + # this htlc belongs to another peer and has to be settled in their htlc_switch + continue + if not chan.can_update_ctx(proposer=LOCAL): + continue + assert chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id) + if chan.hm.was_htlc_failed(htlc_id=htlc_id, htlc_proposer=REMOTE): + # this check is intended to gracefully handle stale htlcs in the set, e.g. after a crash + self.logger.debug(f"{mpp_htlc=} was already failed before, dropping it.") + htlc_set.htlcs.remove(mpp_htlc) + continue + onion_packet = self._parse_onion_packet(mpp_htlc.unprocessed_onion) + processed_onion_packet = self._process_incoming_onion_packet( + onion_packet, + payment_hash=mpp_htlc.htlc.payment_hash, + is_trampoline=False, + ) + if raw_error: + error_bytes = obfuscate_onion_error(raw_error, onion_packet.public_key, self.privkey) + else: + assert isinstance(error_code, (OnionFailureCode, int)) + if error_code == OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS: + amount_to_forward = processed_onion_packet.amt_to_forward + # if this was a trampoline htlc we use the inner amount_to_forward as this is + # the value known by the sender + if processed_onion_packet.trampoline_onion_packet: + processed_trampoline_onion_packet = self._process_incoming_onion_packet( + processed_onion_packet.trampoline_onion_packet, + payment_hash=mpp_htlc.htlc.payment_hash, + is_trampoline=True, + ) + amount_to_forward = processed_trampoline_onion_packet.amt_to_forward + error_data = amount_to_forward.to_bytes(8, byteorder="big") + e = OnionRoutingFailure(code=error_code, data=error_data or b'') + error_bytes = e.to_wire_msg(onion_packet, self.privkey, local_height) + self.fail_htlc( + chan=chan, + htlc_id=htlc_id, + error_bytes=error_bytes, + ) + htlc_set.htlcs.remove(mpp_htlc) + def fail_htlc(self, *, chan: Channel, htlc_id: int, error_bytes: bytes): self.logger.info(f"fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.") assert chan.can_update_ctx(proposer=LOCAL), f"cannot send updates: {chan.short_channel_id}" @@ -2328,7 +2357,7 @@ class Peer(Logger, EventListener): len=len(error_bytes), reason=error_bytes) - def fail_malformed_htlc(self, *, chan: Channel, htlc_id: int, reason: OnionRoutingFailure): + def fail_malformed_htlc(self, *, chan: Channel, htlc_id: int, reason: OnionParsingError): self.logger.info(f"fail_malformed_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.") assert chan.can_update_ctx(proposer=LOCAL), f"cannot send updates: {chan.short_channel_id}" if not (reason.code & OnionFailureCodeMetaFlag.BADONION and len(reason.data) == 32): @@ -2764,79 +2793,107 @@ class Peer(Logger, EventListener): @util.profiler(min_threshold=0.02) def _run_htlc_switch_iteration(self): self._maybe_cleanup_received_htlcs_pending_removal() - # In this loop, an item of chan.unfulfilled_htlcs may go through 4 stages: - # - 1. not forwarded yet: (None, onion_packet_hex) - # - 2. forwarded: (forwarding_key, onion_packet_hex) - # - 3. processed: (forwarding_key, None), not irrevocably removed yet - # - 4. done: (forwarding_key, None), irrevocably removed + # htlc processing happens in two steps: + # 1. Step: Iterating through all channels and their pending htlcs, doing validation + # feasible for single htlcs (some checks only make sense on the whole mpp set) and + # then collecting these htlcs in a mpp set by payment key. + # HTLCs failing these checks will get failed directly and won't be added to any set. + # No htlcs will get settled in this step, settling only happens on complete mpp sets. + # If a new htlc belongs to a set which has already been failed, the htlc will be failed + # and not added to any set. + # Each htlc is only supposed to go through this first loop once when being received. for chan_id, chan in self.channels.items(): if not chan.can_update_ctx(proposer=LOCAL): continue self.maybe_send_commitment(chan) - done = set() unfulfilled = chan.unfulfilled_htlcs - for htlc_id, (onion_packet_hex, forwarding_key) in unfulfilled.items(): + for htlc_id, onion_packet_hex in list(unfulfilled.items()): if not chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id): continue + htlc = chan.hm.get_htlc_by_id(REMOTE, htlc_id) - if chan.hm.is_htlc_irrevocably_removed_yet(htlc_proposer=REMOTE, htlc_id=htlc_id): - assert onion_packet_hex is None - self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc) - if forwarding_key: - self.lnworker.maybe_cleanup_forwarding(forwarding_key) - done.add(htlc_id) - continue - if onion_packet_hex is None: - # has been processed already - continue - error_reason = None # type: Optional[OnionRoutingFailure] - error_bytes = None # type: Optional[bytes] - preimage = None - onion_packet_bytes = bytes.fromhex(onion_packet_hex) - onion_packet = None try: - onion_packet = OnionPacket.from_bytes(onion_packet_bytes) - except OnionRoutingFailure as e: - error_reason = e - else: - try: - preimage, _forwarding_key, error_bytes = self.process_unfulfilled_htlc( - chan=chan, - htlc=htlc, - forwarding_key=forwarding_key, - onion_packet_bytes=onion_packet_bytes, - onion_packet=onion_packet) - if _forwarding_key: - assert forwarding_key is None - unfulfilled[htlc_id] = onion_packet_hex, _forwarding_key - except OnionRoutingFailure as e: - error_bytes = construct_onion_error(e, onion_packet.public_key, self.privkey, self.network.get_local_height()) - if error_bytes: - error_bytes = obfuscate_onion_error(error_bytes, onion_packet.public_key, our_onion_private_key=self.privkey) + onion_packet = self._parse_onion_packet(onion_packet_hex) + except OnionParsingError as e: + self.fail_malformed_htlc( + chan=chan, + htlc_id=htlc.htlc_id, + reason=e, + ) + del unfulfilled[htlc_id] + continue - if preimage or error_reason or error_bytes: - if preimage: - self.lnworker.set_request_status(htlc.payment_hash, PR_PAID) - if not self.lnworker.enable_htlc_settle: - continue - self.fulfill_htlc(chan, htlc.htlc_id, preimage) - elif error_bytes: - self.fail_htlc( - chan=chan, - htlc_id=htlc.htlc_id, - error_bytes=error_bytes) + try: + processed_onion_packet = self._process_incoming_onion_packet( + onion_packet, + payment_hash=htlc.payment_hash, + is_trampoline=False, + ) + payment_key: str = self._check_unfulfilled_htlc( + chan=chan, + htlc=htlc, + processed_onion=processed_onion_packet, + ) + self.lnworker.update_or_create_mpp_with_received_htlc( + payment_key=payment_key, + scid=chan.short_channel_id, + htlc=htlc, + unprocessed_onion_packet=onion_packet_hex, # outer onion if trampoline + ) + except OnionParsingError as e: # could be raised when parsing the inner trampoline onion + self.fail_malformed_htlc( + chan=chan, + htlc_id=htlc.htlc_id, + reason=e, + ) + except Exception as e: + # Fail the htlc directly if it fails to pass these tests, it will not get added to a htlc set. + # https://github.com/lightning/bolts/blob/14272b1bd9361750cfdb3e5d35740889a6b510b5/04-onion-routing.md?plain=1#L388 + reraise = False + if isinstance(e, OnionRoutingFailure): + orf = e else: - self.fail_malformed_htlc( - chan=chan, - htlc_id=htlc.htlc_id, - reason=error_reason) - # blank onion field to mark it as processed - unfulfilled[htlc_id] = None, forwarding_key + orf = OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') + reraise = True # propagate this out, as this might suggest a bug + error_bytes = orf.to_wire_msg(onion_packet, self.privkey, self.network.get_local_height()) + self.fail_htlc( + chan=chan, + htlc_id=htlc.htlc_id, + error_bytes=error_bytes, + ) + if reraise: + raise + finally: + del unfulfilled[htlc_id] - # cleanup - for htlc_id in done: - unfulfilled.pop(htlc_id) - self.maybe_send_commitment(chan) + # 2. Step: Acting on sets of htlcs. + # Doing further checks that have to be done on sets of htlcs (e.g. total amount checks) + # and checks that have to be done continuously like checking for timeout. + # A set marked as failed once must never settle any htlcs associated to it. + # The sets are shared between all peers, so each peers htlc_switch acts on the same sets. + for payment_key, htlc_set in list(self.lnworker.received_mpp_htlcs.items()): + any_error, preimage, callback = self._check_unfulfilled_htlc_set(payment_key, htlc_set) + assert bool(any_error) + bool(preimage) + bool(callback) <= 1, \ + f"{any_error=}, {bool(preimage)=}, {callback=}" + if any_error: + error_tuple = self.lnworker.set_htlc_set_error(payment_key, any_error) + self._fail_htlc_set(payment_key, error_tuple) + if preimage: + if self.lnworker.enable_htlc_settle: + self.lnworker.set_request_status(htlc_set.get_payment_hash(), PR_PAID) + self._fulfill_htlc_set(payment_key, preimage) + if callback: + task = asyncio.create_task(callback()) + task.add_done_callback( # log exceptions occurring in callback + lambda t, pk=payment_key: self.logger.exception( + f"cb failed: " + f"{self.lnworker.received_mpp_htlcs[pk]=}", exc_info=t.exception()) if t.exception() else None + ) + + if len(self.lnworker.received_mpp_htlcs[payment_key].htlcs) == 0: + self.logger.debug(f"deleting resolved mpp set: {payment_key=}") + del self.lnworker.received_mpp_htlcs[payment_key] + self.lnworker.maybe_cleanup_forwarding(payment_key) def _maybe_cleanup_received_htlcs_pending_removal(self) -> None: done = set() @@ -2861,107 +2918,332 @@ class Peer(Logger, EventListener): await group.spawn(htlc_switch_iteration()) await group.spawn(self.got_disconnected.wait()) - def process_unfulfilled_htlc( - self, *, - chan: Channel, - htlc: UpdateAddHtlc, - forwarding_key: Optional[str], - onion_packet_bytes: bytes, - onion_packet: OnionPacket) -> Tuple[Optional[bytes], Optional[str], Optional[bytes]]: - """ - return (preimage, payment_key, error_bytes) with at most a single element that is not None - raise an OnionRoutingFailure if we need to fail the htlc - """ - payment_hash = htlc.payment_hash - processed_onion = self.process_onion_packet( - onion_packet, - payment_hash=payment_hash, - onion_packet_bytes=onion_packet_bytes) + def _log_htlc_fail_reason_cb( + self, + scid: ShortChannelID, + htlc: UpdateAddHtlc, + onion_payload: dict + ) -> Callable[[str], None]: + def _log_fail_reason(reason: str) -> None: + self.logger.info(f"will FAIL HTLC: {str(scid)=}. {reason=}. {str(htlc)=}. {onion_payload=}") + return _log_fail_reason - preimage, forwarding_info = self.maybe_fulfill_htlc( - chan=chan, - htlc=htlc, - processed_onion=processed_onion, - onion_packet_bytes=onion_packet_bytes, - already_forwarded=bool(forwarding_key)) + def _log_htlc_set_fail_reason_cb(self, mpp_set: ReceivedMPPStatus) -> Callable[[str], None]: + def log_fail_reason(reason: str): + for mpp_htlc in mpp_set.htlcs: + try: + processed_onion = self._process_incoming_onion_packet( + onion_packet=self._parse_onion_packet(mpp_htlc.unprocessed_onion), + payment_hash=mpp_htlc.htlc.payment_hash, + is_trampoline=False, + ) + onion_payload = processed_onion.hop_data.payload + except Exception: + onion_payload = {} - if not forwarding_key: - if forwarding_info: - # HTLC we are supposed to forward, but haven't forwarded yet - payment_key, forwarding_callback = forwarding_info + self._log_htlc_fail_reason_cb( + mpp_htlc.scid, + mpp_htlc.htlc, + onion_payload, + )(f"mpp set {id(mpp_set)} failed: {reason}") + + return log_fail_reason + + def _check_unfulfilled_htlc_set( + self, + payment_key: str, + mpp_set: ReceivedMPPStatus + ) -> Tuple[ + Optional[Union[OnionRoutingFailure, OnionFailureCode, bytes]], # error types used to fail the set + Optional[bytes], # preimage to settle the set + Optional[Callable[[], Awaitable[None]]], # callback + ]: + """ + Returns what to do next with the given set of htlcs: + * Fail whole set -> returns error code + * Settle whole set -> Returns preimage + * call callback (e.g. forwarding, hold invoice) + May modify the mpp set in lnworker.received_mpp_htlcs (e.g. by setting its resolution to COMPLETE). + """ + _log_fail_reason = self._log_htlc_set_fail_reason_cb(mpp_set) + + if (final_state := self._check_final_mpp_set_state(payment_key, mpp_set)) is not None: + return final_state + + assert mpp_set.resolution in (RecvMPPResolution.WAITING, RecvMPPResolution.COMPLETE) + chain = self.network.blockchain() + local_height = chain.height() + if chain.is_tip_stale(): + _log_fail_reason(f"our chain tip is stale: {local_height=}") + return OnionFailureCode.TEMPORARY_NODE_FAILURE, None, None + + amount_msat: int = 0 # sum(amount_msat of each htlc) + total_msat = None # type: Optional[int] + payment_hash = mpp_set.get_payment_hash() + closest_cltv_abs = mpp_set.get_closest_cltv_abs() + first_htlc_timestamp = mpp_set.get_first_htlc_timestamp() + processed_onions = {} # type: dict[ReceivedMPPHtlc, Tuple[ProcessedOnionPacket, Optional[ProcessedOnionPacket]]] + for mpp_htlc in mpp_set.htlcs: + processed_onion = self._process_incoming_onion_packet( + onion_packet=self._parse_onion_packet(mpp_htlc.unprocessed_onion), + payment_hash=payment_hash, + is_trampoline=False, # this is always the outer onion + ) + processed_onions[mpp_htlc] = (processed_onion, None) + inner_onion = None + if processed_onion.trampoline_onion_packet: + inner_onion = self._process_incoming_onion_packet( + onion_packet=processed_onion.trampoline_onion_packet, + payment_hash=payment_hash, + is_trampoline=True, + ) + processed_onions[mpp_htlc] = (processed_onion, inner_onion) + + total_msat_outer_onion = processed_onion.total_msat + total_msat_inner_onion = inner_onion.total_msat if inner_onion else None + if total_msat is None: + total_msat = total_msat_inner_onion or total_msat_outer_onion + + # check total_msat is equal for all htlcs of the set + if total_msat != (total_msat_inner_onion or total_msat_outer_onion): + _log_fail_reason(f"total_msat is not uniform: {total_msat=} != {processed_onion.total_msat=}") + return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None + + amount_msat += mpp_htlc.htlc.amount_msat + + # If the set contains outer onions with different payment secrets, the set's payment_key is + # derived from the trampoline/invoice/inner payment secret, so it is the second stage of a + # multi-trampoline payment in which all the trampoline parts/htlcs got combined. + # In this case the amt_to_forward cannot be compared as it may differ between the trampoline parts. + # However, amt_to_forward should be similar for all onions of a single trampoline part and gets + # compared in the first stage where the htlc set represents a single trampoline part. + outer_onions = [onions[0] for onions in processed_onions.values()] + can_have_different_amt_to_fwd = not all(o.payment_secret == outer_onions[0].payment_secret for o in outer_onions) + trampoline_onions = iter(onions[1] for onions in processed_onions.values()) + if not lnonion.compare_trampoline_onions(trampoline_onions, exclude_amt_to_fwd=can_have_different_amt_to_fwd): + _log_fail_reason(f"got inconsistent {trampoline_onions=}") + return OnionFailureCode.INVALID_ONION_PAYLOAD, None, None + + if len(processed_onions) == 1: + outer_onion, inner_onion = next(iter(processed_onions.values())) + if not outer_onion.are_we_final: + assert inner_onion is None, f"{outer_onion=}\n{inner_onion=}" if not self.lnworker.enable_htlc_forwarding: return None, None, None - if payment_key not in self.lnworker.active_forwardings: - async def wrapped_callback(): - forwarding_coro = forwarding_callback() - try: - next_htlc = await forwarding_coro - if next_htlc: - htlc_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc.htlc_id) - self.lnworker.active_forwardings[payment_key].append(next_htlc) - self.lnworker.downstream_to_upstream_htlc[next_htlc] = htlc_key - except OnionRoutingFailure as e: - if len(self.lnworker.active_forwardings[payment_key]) == 0: - self.lnworker.save_forwarding_failure(payment_key, failure_message=e) - # TODO what about other errors? e.g. TxBroadcastError for a swap. - # - malicious electrum server could fake TxBroadcastError - # Could we "catch-all Exception" and fail back the htlcs with e.g. TEMPORARY_NODE_FAILURE? - # - we don't want to fail the inc-HTLC for a syntax error that happens in the callback - # If we don't call save_forwarding_failure(), the inc-HTLC gets stuck until expiry - # and then the inc-channel will get force-closed. - # => forwarding_callback() could have an API with two exceptions types: - # - type1, such as OnionRoutingFailure, that signals we need to fail back the inc-HTLC - # - type2, such as TxBroadcastError, that signals we want to retry the callback - # add to list - assert len(self.lnworker.active_forwardings.get(payment_key, [])) == 0 - self.lnworker.active_forwardings[payment_key] = [] - fut = asyncio.ensure_future(wrapped_callback()) - # return payment_key so this branch will not be executed again - return None, payment_key, None - elif preimage: - return preimage, None, None - else: - # we are waiting for mpp consolidation or preimage + # this is a single (non-trampoline) htlc set which needs to be forwarded. + # set to settling state so it will not be failed or forwarded twice. + self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.SETTLING) + fwd_cb = lambda: self.lnworker.maybe_forward_htlc_set(payment_key, processed_htlc_set=processed_onions) + return None, None, fwd_cb + + assert payment_hash is not None and total_msat is not None + # check for expiry over time and potentially fail the whole set if any + # htlc's cltv becomes too close + blocks_to_expiry = max(0, closest_cltv_abs - local_height) + if blocks_to_expiry < MIN_FINAL_CLTV_DELTA_ACCEPTED: + _log_fail_reason(f"htlc.cltv_abs is unreasonably close") + return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None + + # check for mpp expiry (if incomplete and expired -> fail) + if mpp_set.resolution == RecvMPPResolution.WAITING \ + or not self.lnworker.is_payment_bundle_complete(payment_key): + # maybe this set is COMPLETE but the bundle is not yet completed, so the bundle can be considered WAITING + if int(time.time()) - first_htlc_timestamp > self.lnworker.MPP_EXPIRY \ + or self.lnworker.stopping_soon: + _log_fail_reason(f"MPP TIMEOUT (> {self.lnworker.MPP_EXPIRY} sec)") + return OnionFailureCode.MPP_TIMEOUT, None, None + + if mpp_set.resolution == RecvMPPResolution.WAITING: + # check if set is first stage multi-trampoline payment to us + # first stage trampoline payment: + # is a trampoline payment + we_are_final + payment key is derived from outer onion's payment secret + # (so it is not the payment secret we requested in the invoice, but some secret set by a + # trampoline forwarding node on the route). + # if it is first stage, check if sum(htlcs) >= amount_to_forward of the trampoline_payload. + # If this part is complete, move the htlcs to the overall mpp set of the payment (keyed by inner secret). + # Once the second stage set (the set containing all htlcs of the separate trampoline parts) + # is complete, the payment gets fulfilled. + trampoline_payment_key = None + any_trampoline_onion = next(iter(processed_onions.values()))[1] + if any_trampoline_onion and any_trampoline_onion.are_we_final: + trampoline_payment_secret = any_trampoline_onion.payment_secret + assert trampoline_payment_secret == self.lnworker.get_payment_secret(payment_hash) + trampoline_payment_key = (payment_hash + trampoline_payment_secret).hex() + + if trampoline_payment_key and trampoline_payment_key != payment_key: + # first stage of trampoline payment, the first stage must never get set COMPLETE + if amount_msat >= any_trampoline_onion.amt_to_forward: + # setting the parent key will mark the htlcs to be moved to the parent set + self.logger.debug(f"trampoline part complete. {len(mpp_set.htlcs)=}, " + f"{amount_msat=}. setting parent key: {trampoline_payment_key}") + self.lnworker.received_mpp_htlcs[payment_key] = mpp_set._replace( + parent_set_key=trampoline_payment_key, + ) + elif amount_msat >= total_msat: + # set mpp_set as completed as we have received the full total_msat + mpp_set = self.lnworker.set_mpp_resolution( + payment_key=payment_key, + new_resolution=RecvMPPResolution.COMPLETE, + ) + + # check if this set is a trampoline forwarding and potentially return forwarding callback + # note: all inner trampoline onions are equal (enforced above) + _, any_inner_onion = next(iter(processed_onions.values())) + if any_inner_onion and not any_inner_onion.are_we_final: + # this is a trampoline forwarding + can_forward = mpp_set.resolution == RecvMPPResolution.COMPLETE and self.lnworker.enable_htlc_forwarding + if not can_forward: return None, None, None - else: - # HTLC we are supposed to forward, and have already forwarded - # for final trampoline onions, forwarding failures are stored with forwarding_key (which is the inner key) - payment_key = forwarding_key - preimage = self.lnworker.get_preimage(payment_hash) - error_bytes, error_reason = self.lnworker.get_forwarding_failure(payment_key) - if error_bytes: - return None, None, error_bytes - if error_reason: - raise error_reason - if preimage: - return preimage, None, None + self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.SETTLING) + fwd_cb = lambda: self.lnworker.maybe_forward_htlc_set(payment_key, processed_htlc_set=processed_onions) + return None, None, fwd_cb + + # -- from here on it's assumed this set is a payment for us (not something to forward) -- + payment_info = self.lnworker.get_payment_info(payment_hash) + if payment_info is None: + _log_fail_reason(f"payment info has been deleted") + return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None + + # check invoice expiry, fail set if the invoice has expired before it was completed + if mpp_set.resolution == RecvMPPResolution.WAITING: + if int(time.time()) > payment_info.expiration_ts: + _log_fail_reason(f"invoice is expired {payment_info.expiration_ts=}") + return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None return None, None, None - def process_onion_packet( + if payment_hash.hex() in self.lnworker.dont_settle_htlcs: + # used by hold invoice cli to prevent the htlcs from getting fulfilled automatically + return None, None, None + + preimage = self.lnworker.get_preimage(payment_hash) + hold_invoice_callback = self.lnworker.hold_invoice_callbacks.get(payment_hash) + if not preimage and not hold_invoice_callback: + _log_fail_reason(f"cannot settle, no preimage or callback found for {payment_hash.hex()=}") + return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None + + if not self.lnworker.is_payment_bundle_complete(payment_key): + # don't allow settling before all sets of the bundle are COMPLETE + return None, None, None + else: + # If this set is part of a bundle now all parts are COMPLETE so the bundle can be deleted + # so the individual sets will get fulfilled. + self.lnworker.delete_payment_bundle(payment_key=bytes.fromhex(payment_key)) + + assert mpp_set.resolution == RecvMPPResolution.COMPLETE, "should return earlier if set is incomplete" + if not preimage: + assert hold_invoice_callback is not None, "should have been failed before" + async def callback(): + try: + await hold_invoice_callback(payment_hash) + except OnionRoutingFailure as e: # todo: should this catch all exceptions? + _log_fail_reason(f"hold invoice callback raised {e}") + self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.FAILED) + # mpp set must not be failed unless the consumer calls unregister_hold_invoice and + # callback must only be called once. This is enforced by setting the set to SETTLING. + self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.SETTLING) + return None, None, callback + + # settle htlc set + self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.SETTLING) + return None, preimage, None + + def _check_final_mpp_set_state( + self, + payment_key: str, + mpp_set: ReceivedMPPStatus, + ) -> Optional[Tuple[ + Optional[Union[OnionRoutingFailure, OnionFailureCode, bytes]], # error types used to fail the set + Optional[bytes], # preimage to settle the set + None, # callback + ]]: + """ + handle sets that are already in a state eligible for fulfillment or failure and shouldn't + go through another iteration of _check_unfulfilled_htlc_set. + """ + if len(mpp_set.htlcs) == 0: + # stale set, will get deleted on the next iteration + return None, None, None + + if mpp_set.resolution == RecvMPPResolution.FAILED: + error_bytes, failure_message = self.lnworker.get_forwarding_failure(payment_key) + if error_bytes or failure_message: + return error_bytes or failure_message, None, None + return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None + elif mpp_set.resolution == RecvMPPResolution.EXPIRED: + return OnionFailureCode.MPP_TIMEOUT, None, None + + if mpp_set.parent_set_key: + # this is a complete trampoline part of a multi trampoline payment. Move the htlcs to parent. + parent = self.lnworker.received_mpp_htlcs.get(mpp_set.parent_set_key) + if not parent: + parent = ReceivedMPPStatus( + resolution=RecvMPPResolution.WAITING, + htlcs=set(), + ) + self.lnworker.received_mpp_htlcs[mpp_set.parent_set_key] = parent + parent.htlcs.update(mpp_set.htlcs) + mpp_set.htlcs.clear() + return None, None, None # this set will get deleted as there are no htlcs in it anymore + + assert not mpp_set.parent_set_key + if mpp_set.resolution == RecvMPPResolution.SETTLING: + # this is an ongoing forwarding, or a set that has not yet been fully settled (and removed). + # note the htlcs in SETTLING will not get failed automatically, + # even if timeout comes close, so either a forwarding failure or preimage has to be set + error_bytes, failure_message = self.lnworker.get_forwarding_failure(payment_key) + if error_bytes or failure_message: + # this was a forwarding set and it failed + self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.FAILED) + return error_bytes or failure_message, None, None + preimage = self.lnworker.get_preimage(mpp_set.get_payment_hash()) + return None, preimage, None + + return None + + def _parse_onion_packet(self, onion_packet_hex: str) -> OnionPacket: + """ + https://github.com/lightning/bolts/blob/14272b1bd9361750cfdb3e5d35740889a6b510b5/02-peer-protocol.md?plain=1#L2352 + """ + onion_packet_bytes = None + try: + onion_packet_bytes = bytes.fromhex(onion_packet_hex) + onion_packet = OnionPacket.from_bytes(onion_packet_bytes) + except Exception as parsing_exc: + self.logger.warning(f"unable to parse onion: {str(parsing_exc)}") + onion_parsing_error = OnionParsingError( + code=OnionFailureCodeMetaFlag.BADONION, + data=sha256(onion_packet_bytes or b''), + ) + raise onion_parsing_error + return onion_packet + + def _process_incoming_onion_packet( self, onion_packet: OnionPacket, *, payment_hash: bytes, - onion_packet_bytes: bytes, is_trampoline: bool = False) -> ProcessedOnionPacket: - - failure_data = sha256(onion_packet_bytes) + onion_hash = onion_packet.onion_hash + cache_key = sha256(onion_hash + payment_hash + bytes([is_trampoline])) # type: ignore + if cached_onion := self._processed_onion_cache.get(cache_key): + return cached_onion try: - processed_onion = process_onion_packet( + processed_onion = lnonion.process_onion_packet( onion_packet, our_onion_private_key=self.privkey, associated_data=payment_hash, is_trampoline=is_trampoline) + self._processed_onion_cache[cache_key] = processed_onion except UnsupportedOnionPacketVersion: - raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_VERSION, data=failure_data) + raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_VERSION, data=onion_hash) except InvalidOnionPubkey: - raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_KEY, data=failure_data) + raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_KEY, data=onion_hash) except InvalidOnionMac: - raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_HMAC, data=failure_data) + raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_HMAC, data=onion_hash) except Exception as e: - self.logger.info(f"error processing onion packet: {e!r}") - raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_VERSION, data=failure_data) + self.logger.warning(f"error processing onion packet: {e!r}") + raise OnionParsingError(code=OnionFailureCodeMetaFlag.BADONION, data=onion_hash) if self.network.config.TEST_FAIL_HTLCS_AS_MALFORMED: - raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_VERSION, data=failure_data) + raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_VERSION, data=onion_hash) if self.network.config.TEST_FAIL_HTLCS_WITH_TEMP_NODE_FAILURE: raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') return processed_onion diff --git a/electrum/lnsweep.py b/electrum/lnsweep.py index 014c87ad4..9fededb8a 100644 --- a/electrum/lnsweep.py +++ b/electrum/lnsweep.py @@ -445,6 +445,8 @@ def sweep_our_ctx( if not preimage: # we might not have the preimage if this is a hold invoice continue + if htlc.payment_hash in chan.lnworker.dont_settle_htlcs: + continue else: preimage = None try: @@ -746,6 +748,8 @@ def sweep_their_ctx( if not preimage: # we might not have the preimage if this is a hold invoice continue + if htlc.payment_hash in chan.lnworker.dont_settle_htlcs: + continue else: preimage = None tx_htlc( diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 76a6c876b..28f1f4e36 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -1934,26 +1934,86 @@ class UpdateAddHtlc: # Note: these states are persisted in the wallet file. # Do not modify them without performing a wallet db upgrade +# todo: if this changes again states could also be persisted by name instead of int value as done for ChannelState class RecvMPPResolution(IntEnum): - WAITING = 0 - EXPIRED = 1 - COMPLETE = 2 - FAILED = 3 + WAITING = 0 # set is not complete yet, waiting for arrival of the remaining htlcs + EXPIRED = 1 # preimage must not be revealed + COMPLETE = 2 # set is complete but could still be failed (e.g. due to cltv timeout) + FAILED = 3 # preimage must not be revealed + SETTLING = 4 # Must not be failed, should be settled asap. + # Also used when forwarding (for upstream), in which case a downstream + # forwarding failure could still result in transitioning to FAILED. + + +r = RecvMPPResolution +allowed_mpp_set_transitions = ( + (r.WAITING, r.EXPIRED), + (r.WAITING, r.FAILED), + (r.WAITING, r.COMPLETE), + (r.WAITING, r.SETTLING), # normal htlc forwarding + + (r.COMPLETE, r.SETTLING), + (r.COMPLETE, r.FAILED), + (r.COMPLETE, r.EXPIRED), # this should only realistically happen for payment bundles + + (r.SETTLING, r.FAILED), # forwarding failure, hold invoice callback gets unregistered, and we don't have preimage + + (r.EXPIRED, r.FAILED), # doesn't seem useful but also not dangerous +) +del r + + +class ReceivedMPPHtlc(NamedTuple): + scid: ShortChannelID + htlc: UpdateAddHtlc + unprocessed_onion: str + + def __repr__(self): + return f"{self.scid}, {self.htlc=}, {self.unprocessed_onion[:15]=}..." + + @staticmethod + def from_tuple(scid, htlc, unprocessed_onion) -> 'ReceivedMPPHtlc': + assert is_hex_str(unprocessed_onion) and is_hex_str(scid) + return ReceivedMPPHtlc( + scid=ShortChannelID(bytes.fromhex(scid)), + htlc=UpdateAddHtlc.from_tuple(*htlc), + unprocessed_onion=unprocessed_onion, + ) class ReceivedMPPStatus(NamedTuple): resolution: RecvMPPResolution - expected_msat: int - htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]] + htlcs: set[ReceivedMPPHtlc] + # parent_set_key is needed as trampoline allows MPP to be nested, the parent_set_key is the + # payment key of the final mpp set (derived from inner trampoline onion payment secret) + # to which the separate trampoline sets htlcs get added once they are complete. + # https://github.com/lightning/bolts/pull/829/commits/bc7a1a0bc97b2293e7f43dd8a06529e5fdcf7cd2 + parent_set_key: str = None + + def get_first_htlc_timestamp(self) -> Optional[int]: + return min([mpp_htlc.htlc.timestamp for mpp_htlc in self.htlcs], default=None) + + def get_closest_cltv_abs(self) -> Optional[int]: + return min([mpp_htlc.htlc.cltv_abs for mpp_htlc in self.htlcs], default=None) + + def get_payment_hash(self) -> Optional[bytes]: + mpp_htlcs = iter(self.htlcs) + first_mpp_htlc = next(mpp_htlcs, None) + payment_hash = first_mpp_htlc.htlc.payment_hash if first_mpp_htlc else None + for mpp_htlc in mpp_htlcs: + assert mpp_htlc.htlc.payment_hash == payment_hash, "mpp set with inconsistent payment hashes" + return payment_hash @staticmethod @stored_in('received_mpp_htlcs', tuple) - def from_tuple(resolution, expected_msat, htlc_list) -> 'ReceivedMPPStatus': - htlc_set = set([(ShortChannelID(bytes.fromhex(scid)), UpdateAddHtlc.from_tuple(*x)) for (scid, x) in htlc_list]) + def from_tuple(resolution, htlc_list, parent_set_key=None) -> 'ReceivedMPPStatus': + assert isinstance(resolution, int) + htlc_set = set(ReceivedMPPHtlc.from_tuple(*htlc_data) for htlc_data in htlc_list) return ReceivedMPPStatus( resolution=RecvMPPResolution(resolution), - expected_msat=expected_msat, - htlc_set=htlc_set) + htlcs=htlc_set, + parent_set_key=parent_set_key, + ) class OnionFailureCodeMetaFlag(IntFlag): diff --git a/electrum/lnworker.py b/electrum/lnworker.py index d3cf02e0a..d4c5bc8cb 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -10,7 +10,7 @@ import time from enum import IntEnum from typing import ( Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, NamedTuple, Mapping, Any, Iterable, AsyncGenerator, - Callable, Awaitable + Callable, Awaitable, Union, ) from types import MappingProxyType import threading @@ -70,7 +70,7 @@ from .lnutil import ( ShortChannelID, HtlcLog, NoPathFound, InvalidGossipMsg, FeeBudgetExceeded, ImportedChannelBackupStorage, OnchainChannelBackupStorage, ln_compare_features, IncompatibleLightningFeatures, PaymentFeeBudget, NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE, GossipForwardingMessage, MIN_FUNDING_SAT, - MIN_FINAL_CLTV_DELTA_BUFFER_INVOICE, ReceivedMPPStatus, RecvMPPResolution, + MIN_FINAL_CLTV_DELTA_BUFFER_INVOICE, RecvMPPResolution, ReceivedMPPStatus, ReceivedMPPHtlc, PaymentSuccess, ) from .lnonion import ( @@ -1237,6 +1237,8 @@ class LNWallet(LNWorker): if type(chan) is Channel: self.save_channel(chan) self.clear_invoices_cache() + if chan._state == ChannelState.REDEEMED: + self.maybe_cleanup_mpp(chan) util.trigger_callback('channel', self.wallet, chan) def save_channel(self, chan: Channel): @@ -2399,15 +2401,47 @@ class LNWallet(LNWorker): self._payment_bundles_pkey_to_canon[pkey] = canon_pkey self._payment_bundles_canon_to_pkeylist[canon_pkey] = tuple(payment_keys) - def get_payment_bundle(self, payment_key: bytes) -> Sequence[bytes]: + def get_payment_bundle(self, payment_key: Union[bytes, str]) -> Sequence[bytes]: with self.lock: + if isinstance(payment_key, str): + try: + payment_key = bytes.fromhex(payment_key) + except ValueError: + # might be a forwarding payment_key which is not hex and will never have a bundle + return [] canon_pkey = self._payment_bundles_pkey_to_canon.get(payment_key) if canon_pkey is None: return [] return self._payment_bundles_canon_to_pkeylist[canon_pkey] - def delete_payment_bundle(self, payment_hash: bytes) -> None: - payment_key = self._get_payment_key(payment_hash) + def is_payment_bundle_complete(self, any_payment_key: str) -> bool: + """ + complete means a htlc set is available for each payment key of the payment bundle and + all htlc sets have a resolution >= COMPLETE (we got the whole payment bundle amount) + """ + # get all payment keys covered by this bundle + bundle_payment_keys = self.get_payment_bundle(any_payment_key) + if not bundle_payment_keys: # there is no payment bundle + return True + for payment_key in bundle_payment_keys: + mpp_set = self.received_mpp_htlcs.get(payment_key.hex()) + if mpp_set is None: + # payment bundle is missing htlc set for payment request + # it might have already been failed and deleted + return False + elif mpp_set.resolution not in (RecvMPPResolution.COMPLETE, RecvMPPResolution.SETTLING): + return False + return True + + def delete_payment_bundle( + self, *, + payment_hash: Optional[bytes] = None, + payment_key: Optional[bytes] = None, + ) -> None: + assert (payment_hash is not None) ^ (payment_key is not None), \ + "must provide exactly one of (payment_hash, payment_key)" + if not payment_key: + payment_key = self._get_payment_key(payment_hash) with self.lock: canon_pkey = self._payment_bundles_pkey_to_canon.get(payment_key) if canon_pkey is None: # is it ok for bundle to be missing?? @@ -2478,10 +2512,16 @@ class LNWallet(LNWorker): self.save_payment_info(info, write_to_disk=False) def register_hold_invoice(self, payment_hash: bytes, cb: Callable[[bytes], Awaitable[None]]): + assert self.get_preimage(payment_hash) is None, "hold invoice cb won't get called if preimage is already set" self.hold_invoice_callbacks[payment_hash] = cb def unregister_hold_invoice(self, payment_hash: bytes): - self.hold_invoice_callbacks.pop(payment_hash) + self.hold_invoice_callbacks.pop(payment_hash, None) + payment_key = self._get_payment_key(payment_hash).hex() + if payment_key in self.received_mpp_htlcs: + if self.get_preimage(payment_hash) is None: + # the pending mpp set can be failed as we don't have the preimage to settle it + self.set_mpp_resolution(payment_key, RecvMPPResolution.FAILED) def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None: assert info.status in SAVED_PR_STATUS @@ -2500,132 +2540,142 @@ class LNWallet(LNWorker): if write_to_disk: self.wallet.save_db() - def check_mpp_status( - self, *, - payment_secret: bytes, - short_channel_id: ShortChannelID, - htlc: UpdateAddHtlc, - expected_msat: int, - ) -> RecvMPPResolution: - """Returns the status of the incoming htlc set the given *htlc* belongs to. - - ACCEPTED simply means the mpp set is complete, and we can proceed with further - checks before fulfilling (or failing) the htlcs. - In particular, note that hold-invoice-htlcs typically remain in the ACCEPTED state - for quite some time -- not in the "WAITING" state (which would refer to the mpp set - not yet being complete!). - """ - payment_hash = htlc.payment_hash - payment_key = payment_hash + payment_secret - self.update_mpp_with_received_htlc( - payment_key=payment_key, scid=short_channel_id, htlc=htlc, expected_msat=expected_msat) - mpp_resolution = self.received_mpp_htlcs[payment_key.hex()].resolution - # if still waiting, calc resolution now: - if mpp_resolution == RecvMPPResolution.WAITING: - bundle = self.get_payment_bundle(payment_key) - if bundle: - payment_keys = bundle - else: - payment_keys = [payment_key] - first_timestamp = min([self.get_first_timestamp_of_mpp(pkey) for pkey in payment_keys]) - if self.get_payment_status(payment_hash) == PR_PAID: - mpp_resolution = RecvMPPResolution.COMPLETE - elif self.stopping_soon: - # try to time out pending HTLCs before shutting down - mpp_resolution = RecvMPPResolution.EXPIRED - elif all([self.is_mpp_amount_reached(pkey) for pkey in payment_keys]): - mpp_resolution = RecvMPPResolution.COMPLETE - elif time.time() - first_timestamp > self.MPP_EXPIRY: - mpp_resolution = RecvMPPResolution.EXPIRED - # save resolution, if any. - if mpp_resolution != RecvMPPResolution.WAITING: - for pkey in payment_keys: - if pkey.hex() in self.received_mpp_htlcs: - self.set_mpp_resolution(payment_key=pkey, resolution=mpp_resolution) - - return mpp_resolution - - def update_mpp_with_received_htlc( + def update_or_create_mpp_with_received_htlc( self, *, - payment_key: bytes, + payment_key: str, scid: ShortChannelID, htlc: UpdateAddHtlc, - expected_msat: int, + unprocessed_onion_packet: str, ): - # add new htlc to set - mpp_status = self.received_mpp_htlcs.get(payment_key.hex()) + # Payment key creation: + # * for regular forwarded htlcs -> "scid.hex() + ':%d' % htlc_id" [htlc key] + # * for trampoline forwarding -> "payment hash + payment secret from outer onion" + # * for final non-trampoline htlcs (we are receiver) -> "payment hash + payment secret from onion" + # * for final trampoline htlcs (we are receiver) -> 2. step grouping: + # 1. grouping of htlcs by "payments hash + outer onion payment secret", a 'multi-trampoline mpp part'. + # 2. once the set of step 1. is COMPLETE (amount_fwd outer onion >= total_amt outer onion) + # the htlcs get moved to the parent mpp set (created once first part is complete) grouped by: + # "payment_hash + inner onion payment secret (the one in the invoice)" + # After moving the htlcs the first set gets deleted. + # + # Add the validated htlc to the htlc set associated with the payment key. + # If no set exists, a new set in WAITING state is created. + mpp_status = self.received_mpp_htlcs.get(payment_key) if mpp_status is None: + self.logger.debug(f"creating new mpp set for {payment_key=}") mpp_status = ReceivedMPPStatus( resolution=RecvMPPResolution.WAITING, - expected_msat=expected_msat, - htlc_set=set(), + htlcs=set(), ) - if expected_msat != mpp_status.expected_msat: - self.logger.info( - f"marking received mpp as failed. inconsistent total_msats in bucket. {payment_key.hex()=}") - mpp_status = mpp_status._replace(resolution=RecvMPPResolution.FAILED) - key = (scid, htlc) - if key not in mpp_status.htlc_set: - mpp_status.htlc_set.add(key) # side-effecting htlc_set - self.received_mpp_htlcs[payment_key.hex()] = mpp_status - def set_mpp_resolution(self, *, payment_key: bytes, resolution: RecvMPPResolution): - mpp_status = self.received_mpp_htlcs[payment_key.hex()] - self.logger.info(f'set_mpp_resolution {resolution.name} {len(mpp_status.htlc_set)} {payment_key.hex()}') - self.received_mpp_htlcs[payment_key.hex()] = mpp_status._replace(resolution=resolution) + if mpp_status.resolution > RecvMPPResolution.WAITING: + # we are getting a htlc for a set that is not in WAITING state, it cannot be safely added + self.logger.info(f"htlc set cannot accept htlc, failing htlc: {scid=} {htlc.htlc_id=}") + if mpp_status == RecvMPPResolution.EXPIRED: + raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'') + raise OnionRoutingFailure( + code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, + data=htlc.amount_msat.to_bytes(8, byteorder="big"), + ) - def is_mpp_amount_reached(self, payment_key: bytes) -> bool: - amounts = self.get_mpp_amounts(payment_key) - if amounts is None: - return False - total, expected = amounts - return total >= expected + new_htlc = ReceivedMPPHtlc( + scid=scid, + htlc=htlc, + unprocessed_onion=unprocessed_onion_packet, + ) + assert new_htlc not in mpp_status.htlcs, "each htlc should make it here only once?" + assert isinstance(unprocessed_onion_packet, str) + mpp_status.htlcs.add(new_htlc) # side-effecting htlc_set + self.received_mpp_htlcs[payment_key] = mpp_status - def is_complete_mpp(self, payment_hash: bytes) -> bool: + def set_mpp_resolution(self, payment_key: str, new_resolution: RecvMPPResolution) -> ReceivedMPPStatus: + mpp_status = self.received_mpp_htlcs[payment_key] + if mpp_status.resolution == new_resolution: + return mpp_status + if not (mpp_status.resolution, new_resolution) in lnutil.allowed_mpp_set_transitions: + raise ValueError(f'forbidden mpp set transition: {mpp_status.resolution} -> {new_resolution}') + self.logger.info(f'set_mpp_resolution {new_resolution.name} {len(mpp_status.htlcs)=}: {payment_key=}') + self.received_mpp_htlcs[payment_key] = mpp_status._replace(resolution=new_resolution) + self.wallet.save_db() + return self.received_mpp_htlcs[payment_key] + + def set_htlc_set_error( + self, + payment_key: str, + error: Union[bytes, OnionFailureCode, OnionRoutingFailure], + ) -> Optional[Tuple[Optional[bytes], Optional[OnionFailureCode | int], Optional[bytes]]]: + """ + handles different types of errors and sets the htlc set to failed, then returns a more + structured tuple of error types which can then be used to fail the htlc set + """ + htlc_set = self.received_mpp_htlcs[payment_key] + assert htlc_set.resolution != RecvMPPResolution.SETTLING + raw_error, error_code, error_data = None, None, None + if isinstance(error, bytes): + raw_error = error + elif isinstance(error, OnionFailureCode): + error_code = error + elif isinstance(error, OnionRoutingFailure): + error_code, error_data = OnionFailureCode.from_int(error.code), error.data + else: + raise ValueError(f"invalid error type: {repr(error)}") + + if error_code == OnionFailureCode.MPP_TIMEOUT: + self.set_mpp_resolution(payment_key=payment_key, new_resolution=RecvMPPResolution.EXPIRED) + else: + self.set_mpp_resolution(payment_key=payment_key, new_resolution=RecvMPPResolution.FAILED) + + return raw_error, error_code, error_data + + def get_mpp_resolution(self, payment_hash: bytes) -> Optional[RecvMPPResolution]: payment_key = self._get_payment_key(payment_hash) status = self.received_mpp_htlcs.get(payment_key.hex()) - return status and status.resolution == RecvMPPResolution.COMPLETE + return status.resolution if status else None + + def is_complete_mpp(self, payment_hash: bytes) -> bool: + resolution = self.get_mpp_resolution(payment_hash) + if resolution is not None: + return resolution in (RecvMPPResolution.COMPLETE, RecvMPPResolution.SETTLING) + return False def get_payment_mpp_amount_msat(self, payment_hash: bytes) -> Optional[int]: """Returns the received mpp amount for given payment hash.""" payment_key = self._get_payment_key(payment_hash) - amounts = self.get_mpp_amounts(payment_key) - if not amounts: + total_msat = self.get_mpp_amounts(payment_key) + if not total_msat: return None - total_msat, _ = amounts return total_msat - def get_mpp_amounts(self, payment_key: bytes) -> Optional[Tuple[int, int]]: - """Returns (total received amount, expected amount) or None.""" + def get_mpp_amounts(self, payment_key: bytes) -> Optional[int]: + """Returns total received amount or None.""" mpp_status = self.received_mpp_htlcs.get(payment_key.hex()) if not mpp_status: return None - total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set]) - return total, mpp_status.expected_msat - - def get_first_timestamp_of_mpp(self, payment_key: bytes) -> int: - mpp_status = self.received_mpp_htlcs.get(payment_key.hex()) - if not mpp_status: - return int(time.time()) - return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set]) + total = sum([mpp_htlc.htlc.amount_msat for mpp_htlc in mpp_status.htlcs]) + return total def maybe_cleanup_mpp( self, - short_channel_id: ShortChannelID, - htlc: UpdateAddHtlc, + chan: Channel, ) -> None: - - htlc_key = (short_channel_id, htlc) + """ + Remove all remaining mpp htlcs of the given channel after closing. + Usually they get removed in htlc_switch after all htlcs of the set are resolved, + however if there is a force close with pending htlcs they need to be removed after the channel + is closed. + """ + # only cleanup when channel is REDEEMED as mpp set is still required for lnsweep + assert chan._state == ChannelState.REDEEMED for payment_key_hex, mpp_status in list(self.received_mpp_htlcs.items()): - if htlc_key not in mpp_status.htlc_set: - continue - assert mpp_status.resolution != RecvMPPResolution.WAITING - self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP {payment_key_hex}') - mpp_status.htlc_set.remove(htlc_key) # side-effecting htlc_set - if len(mpp_status.htlc_set) == 0: + htlcs_to_remove = [htlc for htlc in mpp_status.htlcs if htlc.scid == chan.short_channel_id] + for stale_mpp_htlc in htlcs_to_remove: + assert mpp_status.resolution != RecvMPPResolution.WAITING + self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP {payment_key_hex}') + mpp_status.htlcs.remove(stale_mpp_htlc) # side-effecting htlc_set + if len(mpp_status.htlcs) == 0: self.logger.info(f'maybe_cleanup_mpp: removing mpp {payment_key_hex}') - self.received_mpp_htlcs.pop(payment_key_hex) + del self.received_mpp_htlcs[payment_key_hex] self.maybe_cleanup_forwarding(payment_key_hex) def maybe_cleanup_forwarding(self, payment_key_hex: str) -> None: @@ -2681,6 +2731,7 @@ class LNWallet(LNWorker): for payment_key, htlcs in self.active_forwardings.items(): if htlc_key in htlcs: return payment_key + return None def notify_upstream_peer(self, htlc_key: str) -> None: """Called when an HTLC we offered on chan gets irrevocably fulfilled or failed. @@ -3436,7 +3487,58 @@ class LNWallet(LNWorker): util.trigger_callback('channels_updated', self.wallet) self.lnwatcher.add_channel(cb) - async def maybe_forward_htlc( + async def maybe_forward_htlc_set( + self, + payment_key: str, *, + processed_htlc_set: dict[ReceivedMPPHtlc, Tuple[ProcessedOnionPacket, Optional[ProcessedOnionPacket]]], + ) -> None: + assert self.enable_htlc_forwarding + assert payment_key not in self.active_forwardings, "cannot forward set twice" + self.active_forwardings[payment_key] = [] + self.logger.debug(f"adding active_forwarding: {payment_key=}") + + any_mpp_htlc, (any_outer_onion, any_trampoline_onion) = next(iter(processed_htlc_set.items())) + try: + if any_trampoline_onion is None: + assert not any_outer_onion.are_we_final + assert len(processed_htlc_set) == 1, processed_htlc_set + forward_htlc = any_mpp_htlc.htlc + incoming_chan = self.get_channel_by_short_id(any_mpp_htlc.scid) + next_htlc = await self._maybe_forward_htlc( + incoming_chan=incoming_chan, + htlc=forward_htlc, + processed_onion=any_outer_onion, + ) + htlc_key = serialize_htlc_key(incoming_chan.get_scid_or_local_alias(), forward_htlc.htlc_id) + self.active_forwardings[payment_key].append(next_htlc) + self.downstream_to_upstream_htlc[next_htlc] = htlc_key + else: + assert not any_trampoline_onion.are_we_final and any_outer_onion.are_we_final + # trampoline forwarding + min_inc_cltv_abs = min( + mpp_htlc.htlc.cltv_abs + for mpp_htlc in processed_htlc_set.keys()) # take "min" to assume worst-case + await self._maybe_forward_trampoline( + payment_hash=any_mpp_htlc.htlc.payment_hash, + closest_inc_cltv_abs=min_inc_cltv_abs, + total_msat=any_outer_onion.total_msat, + any_trampoline_onion=any_trampoline_onion, + fw_payment_key=payment_key, + ) + except OnionRoutingFailure as e: + self.logger.debug(f"forwarding failed: {e=}") + if len(self.active_forwardings[payment_key]) == 0: + self.save_forwarding_failure(payment_key, failure_message=e) + # TODO what about other errors? + # Could we "catch-all Exception" and fail back the htlcs with e.g. TEMPORARY_NODE_FAILURE? + # - we don't want to fail the inc-HTLC for a syntax error that happens in the callback + # If we don't call save_forwarding_failure(), the inc-HTLC gets stuck until expiry + # and then the inc-channel will get force-closed. + # => forwarding_callback() could have an API with two exceptions types: + # - type1, such as OnionRoutingFailure, that signals we need to fail back the inc-HTLC + # - type2, such as NoPathFound, that signals we want to retry forwarding + + async def _maybe_forward_htlc( self, *, incoming_chan: Channel, htlc: UpdateAddHtlc, @@ -3550,13 +3652,12 @@ class LNWallet(LNWorker): htlc_key = serialize_htlc_key(next_chan.get_scid_or_local_alias(), next_htlc.htlc_id) return htlc_key - @log_exceptions - async def maybe_forward_trampoline( + async def _maybe_forward_trampoline( self, *, payment_hash: bytes, - inc_cltv_abs: int, - outer_onion: ProcessedOnionPacket, - trampoline_onion: ProcessedOnionPacket, + closest_inc_cltv_abs: int, + total_msat: int, # total_msat of the outer onion + any_trampoline_onion: ProcessedOnionPacket, # any trampoline onion of the incoming htlc set, they should be similar fw_payment_key: str, ) -> None: @@ -3565,7 +3666,7 @@ class LNWallet(LNWorker): if not (forwarding_enabled and forwarding_trampoline_enabled): self.logger.info(f"trampoline forwarding is disabled. failing htlc.") raise OnionRoutingFailure(code=OnionFailureCode.PERMANENT_CHANNEL_FAILURE, data=b'') - payload = trampoline_onion.hop_data.payload + payload = any_trampoline_onion.hop_data.payload payment_data = payload.get('payment_data') try: payment_secret = payment_data['payment_secret'] if payment_data else os.urandom(32) @@ -3583,7 +3684,7 @@ class LNWallet(LNWorker): else: self.logger.info('forward_trampoline: end-to-end') invoice_features = LnFeatures.BASIC_MPP_OPT - next_trampoline_onion = trampoline_onion.next_packet + next_trampoline_onion = any_trampoline_onion.next_packet r_tags = [] except Exception as e: self.logger.exception('') @@ -3597,13 +3698,12 @@ class LNWallet(LNWorker): # these are the fee/cltv paid by the sender # pay_to_node will raise if they are not sufficient - total_msat = outer_onion.hop_data.payload["payment_data"]["total_msat"] budget = PaymentFeeBudget( fee_msat=total_msat - amt_to_forward, - cltv=inc_cltv_abs - out_cltv_abs, + cltv=closest_inc_cltv_abs - out_cltv_abs, ) self.logger.info(f'trampoline forwarding. budget={budget}') - self.logger.info(f'trampoline forwarding. {inc_cltv_abs=}, {out_cltv_abs=}') + self.logger.info(f'trampoline forwarding. {closest_inc_cltv_abs=}, {out_cltv_abs=}') # To convert abs vs rel cltvs, we need to guess blockheight used by original sender as "current blockheight". # Blocks might have been mined since. # - if we skew towards the past, we decrease our own cltv_budget accordingly (which is ok) @@ -3715,7 +3815,6 @@ class LNWallet(LNWorker): final_cltv_abs=final_cltv_abs, total_msat=total_msat, payment_secret=payment_secret) - num_hops = len(hops_data) self.logger.info(f"pay len(route)={len(route)}. for payment_hash={payment_hash.hex()}") for i in range(len(route)): self.logger.info(f" {i}: edge={route[i].short_channel_id} hop_data={hops_data[i]!r}") diff --git a/electrum/submarine_swaps.py b/electrum/submarine_swaps.py index 3705aa572..de6a40a54 100644 --- a/electrum/submarine_swaps.py +++ b/electrum/submarine_swaps.py @@ -37,8 +37,7 @@ from .util import ( run_sync_function_on_asyncio_thread, trigger_callback, NoDynamicFeeEstimates, UserFacingException, ) from . import lnutil -from .lnutil import (hex_to_bytes, REDEEM_AFTER_DOUBLE_SPENT_DELAY, Keypair, - MIN_FINAL_CLTV_DELTA_ACCEPTED) +from .lnutil import hex_to_bytes, REDEEM_AFTER_DOUBLE_SPENT_DELAY, Keypair from .lnaddr import lndecode from .json_db import StoredObject, stored_in from . import constants @@ -257,7 +256,7 @@ class SwapManager(Logger): payment_hash = bytes.fromhex(payment_hash_hex) swap._payment_hash = payment_hash self._add_or_reindex_swap(swap, is_new=False) - if not swap.is_reverse and not swap.is_redeemed: + if not swap.is_reverse and not swap.is_redeemed and not self.lnworker.get_preimage(swap.payment_hash): self.lnworker.register_hold_invoice(payment_hash, self.hold_invoice_callback) self._prepayments = {} # type: Dict[bytes, bytes] # fee_rhash -> rhash @@ -399,8 +398,8 @@ class SwapManager(Logger): def _fail_swap(self, swap: SwapData, reason: str): self.logger.info(f'failing swap {swap.payment_hash.hex()}: {reason}') if not swap.is_reverse and swap.payment_hash in self.lnworker.hold_invoice_callbacks: + # unregister_hold_invoice will fail pending htlcs if there is no preimage available self.lnworker.unregister_hold_invoice(swap.payment_hash) - # Peer.maybe_fulfill_htlc will fail incoming htlcs if there is no payment info self.lnworker.delete_payment_info(swap.payment_hash.hex()) self.lnworker.clear_invoices_cache() self.lnwatcher.remove_callback(swap.lockup_address) @@ -415,7 +414,7 @@ class SwapManager(Logger): self._prepayments.pop(swap.prepay_hash, None) if self.lnworker.get_payment_status(swap.prepay_hash) != PR_PAID: self.lnworker.delete_payment_info(swap.prepay_hash.hex()) - self.lnworker.delete_payment_bundle(swap.payment_hash) + self.lnworker.delete_payment_bundle(payment_hash=swap.payment_hash) @classmethod def extract_preimage(cls, swap: SwapData, claim_tx: Transaction) -> Optional[bytes]: @@ -473,7 +472,7 @@ class SwapManager(Logger): # cleanup self.lnwatcher.remove_callback(swap.lockup_address) if not swap.is_reverse: - self.lnworker.delete_payment_bundle(swap.payment_hash) + self.lnworker.delete_payment_bundle(payment_hash=swap.payment_hash) self.lnworker.unregister_hold_invoice(swap.payment_hash) if not swap.is_reverse: @@ -690,7 +689,7 @@ class SwapManager(Logger): self.lnworker.add_payment_info_for_hold_invoice( payment_hash, lightning_amount_sat=invoice_amount_sat, - min_final_cltv_delta=min_final_cltv_expiry_delta or MIN_FINAL_CLTV_DELTA_ACCEPTED, + min_final_cltv_delta=min_final_cltv_expiry_delta or lnutil.MIN_FINAL_CLTV_DELTA_ACCEPTED, exp_delay=300, ) info = self.lnworker.get_payment_info(payment_hash) @@ -709,7 +708,7 @@ class SwapManager(Logger): if prepay: prepay_hash = self.lnworker.create_payment_info( amount_msat=prepay_amount_sat*1000, - min_final_cltv_delta=min_final_cltv_expiry_delta or MIN_FINAL_CLTV_DELTA_ACCEPTED, + min_final_cltv_delta=min_final_cltv_expiry_delta or lnutil.MIN_FINAL_CLTV_DELTA_ACCEPTED, exp_delay=300, ) info = self.lnworker.get_payment_info(prepay_hash) diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py index 09cf9d948..bf22f03fb 100644 --- a/electrum/wallet_db.py +++ b/electrum/wallet_db.py @@ -22,33 +22,29 @@ # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import os -import ast import datetime import json import copy -import threading from collections import defaultdict from typing import (Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence, TYPE_CHECKING, Union, AbstractSet) -import binascii import time from functools import partial import attr -from . import util, bitcoin -from .util import profiler, WalletFileException, multisig_type, TxMinedInfo, bfh, MyEncoder -from .invoices import Invoice, Request +from . import bitcoin +from .util import profiler, WalletFileException, multisig_type, TxMinedInfo, MyEncoder from .keystore import bip44_derivation from .transaction import Transaction, TxOutpoint, tx_from_any, PartialTransaction, PartialTxOutput, BadHeaderMagic from .logging import Logger -from .lnutil import HTLCOwner, ChannelType +from .lnutil import HTLCOwner, ChannelType, RecvMPPResolution from . import json_db -from .json_db import StoredDict, JsonDB, locked, modifier, StoredObject, stored_in, stored_as +from .json_db import JsonDB, locked, modifier, StoredObject, stored_in, stored_as from .plugin import run_hook, plugin_loaders from .version import ELECTRUM_VERSION +from .i18n import _ if TYPE_CHECKING: from .storage import WalletStorage @@ -73,7 +69,7 @@ class WalletUnfinished(WalletFileException): # seed_version is now used for the version of the wallet file OLD_SEED_VERSION = 4 # electrum versions < 2.0 NEW_SEED_VERSION = 11 # electrum versions >= 2.0 -FINAL_SEED_VERSION = 62 # electrum >= 2.7 will set this to prevent +FINAL_SEED_VERSION = 63 # electrum >= 2.7 will set this to prevent # old versions from overwriting new format @@ -238,6 +234,7 @@ class WalletDBUpgrader(Logger): self._convert_version_60() self._convert_version_61() self._convert_version_62() + self._convert_version_63() self.put('seed_version', FINAL_SEED_VERSION) # just to be sure def _convert_wallet_type(self): @@ -1182,6 +1179,96 @@ class WalletDBUpgrader(Logger): swap['claim_to_output'] = None self.data['seed_version'] = 62 + def _convert_version_63(self): + if not self._is_upgrade_method_needed(62, 62): + return + # Old ReceivedMPPStatus: + # class ReceivedMPPStatus(NamedTuple): + # resolution: RecvMPPResolution + # expected_msat: int + # htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]] + # + # New ReceivedMPPStatus: + # class ReceivedMPPStatus(NamedTuple): + # resolution: RecvMPPResolution + # htlcs: set[ReceivedMPPHtlc] + # + # class ReceivedMPPHtlc(NamedTuple): + # scid: ShortChannelID + # htlc: UpdateAddHtlc + # unprocessed_onion: str + + # previously chan.unfulfilled_htlcs went through 4 stages: + # - 1. not forwarded yet: (onion_packet_hex, None) + # - 2. forwarded: (onion_packet_hex, forwarding_key) + # - 3. processed: (None, forwarding_key), not irrevocably removed yet + # - 4. done: (None, forwarding_key), irrevocably removed + channels = self.data.get('channels', {}) + def _move_unprocessed_onion(short_channel_id: str, htlc_id: Optional[int]) -> Optional[Tuple[str, Optional[str]]]: + if htlc_id is None: + return None + for chan_ in channels.values(): + if chan_['short_channel_id'] != short_channel_id: + continue + unfulfilled_htlcs_ = chan_.get('unfulfilled_htlcs', {}) + htlc_data = unfulfilled_htlcs_.get(str(htlc_id)) + if htlc_data is None: + return None + stored_onion_packet, htlc_forwarding_key = htlc_data + if stored_onion_packet is not None: + htlc_data[0] = None # overwrite the onion so it is not processed again in htlc_switch + return stored_onion_packet, htlc_forwarding_key + return None + + mpp_sets = self.data.get('received_mpp_htlcs', {}) + for payment_key, recv_mpp_status in list(mpp_sets.items()): + assert isinstance(recv_mpp_status, list), f"{recv_mpp_status=}" + del recv_mpp_status[1] # remove expected_msat + + new_type_htlcs = [] + forwarding_key = None + for scid, update_add_htlc in recv_mpp_status[1]: # htlc_set + htlc_info_from_chan = _move_unprocessed_onion(scid, update_add_htlc[3]) + if htlc_info_from_chan is None: + # if there is no onion packet for the htlc it is dropped as it was already + # processed in the old htlc_switch + continue + onion_packet_hex = htlc_info_from_chan[0] + forwarding_key = htlc_info_from_chan[1] if htlc_info_from_chan[1] else forwarding_key + new_type_htlcs.append([ + scid, + update_add_htlc, + onion_packet_hex, + ]) + + if len(new_type_htlcs) == 0: + self.logger.debug(f"_convert_version_62: dropping mpp set {payment_key=}.") + del mpp_sets[payment_key] + else: + recv_mpp_status[1] = new_type_htlcs + self.logger.debug(f"_convert_version_62: migrated mpp set {payment_key=}") + if forwarding_key is not None: + # if the forwarding key is set for the old mpp set it was either a forwarding + # or a swap hold invoice. Assuming users of 4.6.2 don't use forwarding this update + # most likely happens during a swap waiting for the preimage. Setting the mpp set + # to SETTLING prevents us from accidentally failing the htlc set after the update, + # however it carries the risk of the channel getting force closed if the swap fails + # as the htlcs won't get failed due to the new SETTLING state + # unless a forwarding error is set. + recv_mpp_status[0] = 4 # RecvMPPResolution.SETTLING + + # replace Tuple[onion, forwarding_key] with just the onion in chan['unfulfilled_htlcs'] + for chan in channels.values(): + unfulfilled_htlcs = chan.get('unfulfilled_htlcs', {}) + for htlc_id, (unprocessed_onion, forwarding_key) in list(unfulfilled_htlcs.items()): + if unprocessed_onion is None: + # delete all unfulfilled_htlcs with empty onion as they are already processed + del unfulfilled_htlcs[htlc_id] + else: + unfulfilled_htlcs[htlc_id] = unprocessed_onion + + self.data['seed_version'] = 63 + def _convert_imported(self): if not self._is_upgrade_method_needed(0, 13): return diff --git a/tests/test_commands.py b/tests/test_commands.py index 28a80a607..455f7b87e 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -549,13 +549,13 @@ class TestCommandsTestnet(ElectrumTestCase): ) mock_htlc1 = mock.Mock() - mock_htlc1.cltv_abs = 800_000 - mock_htlc1.amount_msat = 4_500_000 + mock_htlc1.htlc.cltv_abs = 800_000 + mock_htlc1.htlc.amount_msat = 4_500_000 mock_htlc2 = mock.Mock() - mock_htlc2.cltv_abs = 800_144 - mock_htlc2.amount_msat = 5_500_000 + mock_htlc2.htlc.cltv_abs = 800_144 + mock_htlc2.htlc.amount_msat = 5_500_000 mock_htlc_status = mock.Mock() - mock_htlc_status.htlc_set = [(None, mock_htlc1), (None, mock_htlc2)] + mock_htlc_status.htlcs = [mock_htlc1, mock_htlc2] mock_htlc_status.resolution = RecvMPPResolution.COMPLETE payment_key = wallet.lnworker._get_payment_key(bytes.fromhex(payment_hash)).hex() diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 360ede67e..fbd148f70 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -301,7 +301,6 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): set_request_status = LNWallet.set_request_status set_payment_status = LNWallet.set_payment_status get_payment_status = LNWallet.get_payment_status - check_mpp_status = LNWallet.check_mpp_status htlc_fulfilled = LNWallet.htlc_fulfilled htlc_failed = LNWallet.htlc_failed save_preimage = LNWallet.save_preimage @@ -334,11 +333,9 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): unregister_hold_invoice = LNWallet.unregister_hold_invoice add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice - update_mpp_with_received_htlc = LNWallet.update_mpp_with_received_htlc + update_or_create_mpp_with_received_htlc = LNWallet.update_or_create_mpp_with_received_htlc set_mpp_resolution = LNWallet.set_mpp_resolution - is_mpp_amount_reached = LNWallet.is_mpp_amount_reached get_mpp_amounts = LNWallet.get_mpp_amounts - get_first_timestamp_of_mpp = LNWallet.get_first_timestamp_of_mpp bundle_payments = LNWallet.bundle_payments get_payment_bundle = LNWallet.get_payment_bundle _get_payment_key = LNWallet._get_payment_key @@ -347,11 +344,14 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): maybe_cleanup_forwarding = LNWallet.maybe_cleanup_forwarding current_target_feerate_per_kw = LNWallet.current_target_feerate_per_kw current_low_feerate_per_kw_srk_channel = LNWallet.current_low_feerate_per_kw_srk_channel - maybe_cleanup_mpp = LNWallet.maybe_cleanup_mpp create_onion_for_route = LNWallet.create_onion_for_route - maybe_forward_htlc = LNWallet.maybe_forward_htlc - maybe_forward_trampoline = LNWallet.maybe_forward_trampoline + maybe_forward_htlc_set = LNWallet.maybe_forward_htlc_set + _maybe_forward_htlc = LNWallet._maybe_forward_htlc + _maybe_forward_trampoline = LNWallet._maybe_forward_trampoline _maybe_refuse_to_forward_htlc_that_corresponds_to_payreq_we_created = LNWallet._maybe_refuse_to_forward_htlc_that_corresponds_to_payreq_we_created + set_htlc_set_error = LNWallet.set_htlc_set_error + is_payment_bundle_complete = LNWallet.is_payment_bundle_complete + delete_payment_bundle = LNWallet.delete_payment_bundle _process_htlc_log = LNWallet._process_htlc_log