diff --git a/electrum/commands.py b/electrum/commands.py index 94d0d192b..f249f55f9 100644 --- a/electrum/commands.py +++ b/electrum/commands.py @@ -1508,7 +1508,7 @@ class Commands(Logger): payment_key: str = wallet.lnworker._get_payment_key(bfh(payment_hash)).hex() htlc_status = wallet.lnworker.received_mpp_htlcs[payment_key] result["closest_htlc_expiry_height"] = min( - htlc.cltv_abs for _, htlc in htlc_status.htlc_set + mpp_htlc.htlc.cltv_abs for mpp_htlc in htlc_status.htlcs ) elif wallet.lnworker.get_preimage_hex(payment_hash) is not None \ and payment_hash not in wallet.lnworker.dont_settle_htlcs: diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index dd5fedaca..36c2dad7f 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -783,8 +783,8 @@ class Channel(AbstractChannel): self.onion_keys = state['onion_keys'] # type: Dict[int, bytes] self.data_loss_protect_remote_pcp = state['data_loss_protect_remote_pcp'] self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate) - self.unfulfilled_htlcs = state["unfulfilled_htlcs"] # type: Dict[int, Tuple[str, Optional[str]]] - # ^ htlc_id -> onion_packet_hex, forwarding_key + self.unfulfilled_htlcs = state["unfulfilled_htlcs"] # type: Dict[int, Optional[str]] + # ^ htlc_id -> onion_packet_hex self._state = ChannelState[state['state']] self.peer_state = PeerState.DISCONNECTED self._outgoing_channel_update = None # type: Optional[bytes] @@ -1112,6 +1112,7 @@ class Channel(AbstractChannel): if amount_msat <= 0: raise PaymentFailure("HTLC value must be positive") if amount_msat < chan_config.htlc_minimum_msat: + # todo: for incoming htlcs this could be handled more gracefully with `amount_below_minimum` raise PaymentFailure(f'HTLC value too small: {amount_msat} msat') if self.htlc_slots_left(htlc_proposer) == 0: @@ -1226,7 +1227,7 @@ class Channel(AbstractChannel): with self.db_lock: self.hm.recv_htlc(htlc) if onion_packet: - self.unfulfilled_htlcs[htlc.htlc_id] = onion_packet.hex(), None + self.unfulfilled_htlcs[htlc.htlc_id] = onion_packet.hex() self.logger.info("receive_htlc") return htlc diff --git a/electrum/lnonion.py b/electrum/lnonion.py index f624627e1..b57c04792 100644 --- a/electrum/lnonion.py +++ b/electrum/lnonion.py @@ -26,7 +26,8 @@ import io import hashlib from functools import cached_property -from typing import Sequence, List, Tuple, NamedTuple, TYPE_CHECKING, Dict, Any, Optional, Union, Mapping +from typing import (Sequence, List, Tuple, NamedTuple, TYPE_CHECKING, Dict, Any, Optional, Union, + Mapping, Iterator) from enum import IntEnum from dataclasses import dataclass, field, replace from types import MappingProxyType @@ -485,6 +486,55 @@ def process_onion_packet( return ProcessedOnionPacket(are_we_final, hop_data, next_onion_packet, trampoline_onion_packet) +def compare_trampoline_onions( + trampoline_onions: Iterator[Optional[ProcessedOnionPacket]], + *, + exclude_amt_to_fwd: bool = False, +) -> bool: + """ + compare values of trampoline onions payloads and are_we_final. + If we are receiver of a multi trampoline payment amt_to_fwd can differ between the trampoline + parts of the payment, so it needs to be excluded from the comparison when comparing all trampoline + onions of the whole payment (however it can be compared between the onions in a single trampoline part). + """ + try: + first_onion = next(trampoline_onions) + except StopIteration: + raise ValueError("nothing to compare") + + if first_onion is None: + # we don't support mixed mpp sets of htlcs with trampoline onions and regular non-trampoline htlcs. + # In theory this could happen if a sender e.g. uses trampoline as fallback to deliver + # outstanding mpp parts if local pathfinding wasn't successful for the whole payment, + # resulting in a mixed payment. However, it's not even clear if the spec allows for such a constellation. + return all(onion is None for onion in trampoline_onions) + assert isinstance(first_onion, ProcessedOnionPacket), f"{first_onion=}" + + are_we_final = first_onion.are_we_final + payload = first_onion.hop_data.payload + total_msat = first_onion.total_msat + outgoing_cltv = first_onion.outgoing_cltv_value + payment_secret = first_onion.payment_secret + for onion in trampoline_onions: + if onion is None: + return False + assert isinstance(onion, ProcessedOnionPacket), f"{onion=}" + assert onion.trampoline_onion_packet is None, f"{onion=} cannot have trampoline_onion_packet" + if onion.are_we_final != are_we_final: + return False + if not exclude_amt_to_fwd: + if onion.hop_data.payload != payload: + return False + else: + if onion.total_msat != total_msat: + return False + if onion.outgoing_cltv_value != outgoing_cltv: + return False + if onion.payment_secret != payment_secret: + return False + return True + + class FailedToDecodeOnionError(Exception): pass @@ -521,6 +571,14 @@ class OnionRoutingFailure(Exception): payload = None return payload + def to_wire_msg(self, onion_packet: OnionPacket, privkey: bytes, local_height: int) -> bytes: + onion_error = construct_onion_error(self, onion_packet.public_key, privkey, local_height) + error_bytes = obfuscate_onion_error(onion_error, onion_packet.public_key, privkey) + return error_bytes + + +class OnionParsingError(OnionRoutingFailure): pass + def construct_onion_error( error: OnionRoutingFailure, diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 72db12657..377dd375f 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -21,6 +21,7 @@ from electrum_ecc import ecdsa_sig64_from_r_and_s, ecdsa_der_sig_from_ecdsa_sig6 import aiorpcx from aiorpcx import ignore_after +from .lrucache import LRUCache from .crypto import sha256, sha256d, privkey_to_pubkey from . import bitcoin, util from . import constants @@ -31,11 +32,11 @@ from . import transaction from .bitcoin import make_op_return, DummyAddress from .transaction import PartialTxOutput, match_script_against_template, Sighash from .logging import Logger -from .lnrouter import RouteEdge -from .lnonion import (new_onion_packet, OnionFailureCode, calc_hops_data_for_payment, process_onion_packet, - OnionPacket, construct_onion_error, obfuscate_onion_error, OnionRoutingFailure, - ProcessedOnionPacket, UnsupportedOnionPacketVersion, InvalidOnionMac, InvalidOnionPubkey, - OnionFailureCodeMetaFlag) +from . import lnonion +from .lnonion import (OnionFailureCode, OnionPacket, obfuscate_onion_error, + OnionRoutingFailure, ProcessedOnionPacket, UnsupportedOnionPacketVersion, + InvalidOnionMac, InvalidOnionPubkey, OnionFailureCodeMetaFlag, + OnionParsingError) from .lnchannel import Channel, RevokeAndAck, ChannelState, PeerState, ChanCloseOption, CF_ANNOUNCE_CHANNEL from . import lnutil from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, ChannelConfig, @@ -46,17 +47,15 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, ChannelConf ln_compare_features, MIN_FINAL_CLTV_DELTA_ACCEPTED, RemoteMisbehaving, ShortChannelID, IncompatibleLightningFeatures, ChannelType, LNProtocolWarning, validate_features, - IncompatibleOrInsaneFeatures, FeeBudgetExceeded, + IncompatibleOrInsaneFeatures, ReceivedMPPStatus, ReceivedMPPHtlc, GossipForwardingMessage, GossipTimestampFilter, channel_id_from_funding_tx, - PaymentFeeBudget, serialize_htlc_key, Keypair, RecvMPPResolution) + serialize_htlc_key, Keypair, RecvMPPResolution) from .lntransport import LNTransport, LNTransportBase, LightningPeerConnectionClosed, HandshakeFailed from .lnmsg import encode_msg, decode_msg, UnknownOptionalMsgType, FailedToParseMsg from .interface import GracefulDisconnect -from .lnrouter import fee_for_edge_msat from .json_db import StoredDict from .invoices import PR_PAID from .fee_policy import FEE_LN_ETA_TARGET, FEERATE_PER_KW_MIN_RELAY_LIGHTNING -from .trampoline import decode_routing_info if TYPE_CHECKING: from .lnworker import LNGossip, LNWallet @@ -132,6 +131,7 @@ class Peer(Logger, EventListener): self.downstream_htlc_resolved_event = asyncio.Event() self.register_callbacks() self._num_gossip_messages_forwarded = 0 + self._processed_onion_cache = LRUCache(maxsize=100) # type: LRUCache[bytes, ProcessedOnionPacket] def send_message(self, message_name: str, **kwargs): assert util.get_running_loop() == util.get_asyncio_loop(), f"this must be run on the asyncio thread!" @@ -2136,177 +2136,151 @@ class Peer(Logger, EventListener): return payment_secret_from_onion, total_msat, channel_opening_fee, exc_incorrect_or_unknown_pd - def check_mpp_is_waiting( - self, - *, - payment_secret: bytes, - short_channel_id: ShortChannelID, + def _check_unfulfilled_htlc( + self, *, + chan: Channel, htlc: UpdateAddHtlc, - expected_msat: int, - exc_incorrect_or_unknown_pd: OnionRoutingFailure, - log_fail_reason: Callable[[str], None], - ) -> bool: - mpp_resolution = self.lnworker.check_mpp_status( - payment_secret=payment_secret, - short_channel_id=short_channel_id, - htlc=htlc, - expected_msat=expected_msat, - ) - if mpp_resolution == RecvMPPResolution.WAITING: - return True - elif mpp_resolution == RecvMPPResolution.EXPIRED: - log_fail_reason(f"MPP_TIMEOUT") - raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'') - elif mpp_resolution == RecvMPPResolution.FAILED: - log_fail_reason(f"mpp_resolution is FAILED") - raise exc_incorrect_or_unknown_pd - elif mpp_resolution == RecvMPPResolution.COMPLETE: - return False - else: - raise Exception(f"unexpected {mpp_resolution=}") - - def maybe_fulfill_htlc( - self, *, - chan: Channel, - htlc: UpdateAddHtlc, - processed_onion: ProcessedOnionPacket, - outer_onion_payment_secret: bytes = None, # used to group trampoline htlcs for forwarding - onion_packet_bytes: bytes, - already_forwarded: bool = False, - ) -> Tuple[Optional[bytes], Optional[Tuple[str, Callable[[], Awaitable[Optional[str]]]]]]: + processed_onion: ProcessedOnionPacket, + outer_onion_payment_secret: bytes = None, # used to group trampoline htlcs for forwarding + ) -> str: """ - Decide what to do with an HTLC: return preimage if it can be fulfilled, forwarding callback if it can be forwarded. - Return (preimage, (payment_key, callback)) with at most a single element not None. + Does additional checks on the incoming htlc and return the payment key if the tests pass, + otherwise raises OnionRoutingError which will get the htlc failed. """ - if not processed_onion.are_we_final: - if not self.lnworker.enable_htlc_forwarding: - return None, None - # use the htlc key if we are forwarding - payment_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc.htlc_id) - callback = lambda: self.lnworker.maybe_forward_htlc( - incoming_chan=chan, - htlc=htlc, - processed_onion=processed_onion) - return None, (payment_key, callback) + _log_fail_reason = self._log_htlc_fail_reason_cb(chan.short_channel_id, htlc, processed_onion.hop_data.payload) - def log_fail_reason(reason: str): - self.logger.info( - f"maybe_fulfill_htlc. will FAIL HTLC: chan {chan.short_channel_id}. " - f"{reason}. htlc={str(htlc)}. onion_payload={processed_onion.hop_data.payload}") - - chain = self.network.blockchain() # Check that our blockchain tip is sufficiently recent so that we have an approx idea of the height. # We should not release the preimage for an HTLC that its sender could already time out as # then they might try to force-close and it becomes a race. - if chain.is_tip_stale() and not already_forwarded: - log_fail_reason(f"our chain tip is stale") - raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') + chain = self.network.blockchain() local_height = chain.height() + blocks_to_expiry = max(htlc.cltv_abs - local_height, 0) + if chain.is_tip_stale(): + _log_fail_reason(f"our chain tip is stale: {local_height=}") + raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') + + payment_hash = htlc.payment_hash + if not processed_onion.are_we_final: + if outer_onion_payment_secret: + # this is a trampoline forwarding htlc, multiple incoming trampoline htlcs can be collected + payment_key = (payment_hash + outer_onion_payment_secret).hex() + return payment_key + # this is a regular htlc to forward, it will get its own set of size 1 keyed by htlc_key + # Additional checks required only for forwarding nodes will be done in maybe_forward_htlc(). + payment_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc.htlc_id) + return payment_key # parse parameters and perform checks that are invariant - payment_secret_from_onion, total_msat, channel_opening_fee, exc_incorrect_or_unknown_pd = self._check_accepted_final_htlc( - chan=chan, - htlc=htlc, - processed_onion=processed_onion, - is_trampoline_onion=bool(outer_onion_payment_secret), - log_fail_reason=log_fail_reason) - - # payment key for final onions - payment_hash = htlc.payment_hash - payment_key = (payment_hash + payment_secret_from_onion).hex() - - if self.check_mpp_is_waiting( - payment_secret=payment_secret_from_onion, - short_channel_id=chan.get_scid_or_local_alias(), + payment_secret_from_onion, total_msat, channel_opening_fee, exc_incorrect_or_unknown_pd = ( + self._check_accepted_final_htlc( + chan=chan, htlc=htlc, - expected_msat=total_msat, - exc_incorrect_or_unknown_pd=exc_incorrect_or_unknown_pd, - log_fail_reason=log_fail_reason, - ): - return None, None + processed_onion=processed_onion, + is_trampoline_onion=bool(outer_onion_payment_secret), + log_fail_reason=_log_fail_reason, + )) + # trampoline htlcs of which we are the final receiver will first get grouped by the outer + # onions secret to allow grouping a multi-trampoline mpp in different sets. Once a trampoline + # payment part is completed (sum(htlcs) >= (trampoline-)amt_to_forward), its htlcs get moved into + # the htlc set representing the whole payment (payment key derived from trampoline/invoice secret). + payment_key = (payment_hash + (outer_onion_payment_secret or payment_secret_from_onion)).hex() - # TODO check against actual min_final_cltv_expiry_delta from invoice (and give 2-3 blocks of leeway?) - # note: payment_bundles might get split here, e.g. one payment is "already forwarded" and the other is not. - # In practice, for the swap prepayment use case, this does not matter. - if local_height + MIN_FINAL_CLTV_DELTA_ACCEPTED > htlc.cltv_abs and not already_forwarded: - log_fail_reason(f"htlc.cltv_abs is unreasonably close") + if blocks_to_expiry < MIN_FINAL_CLTV_DELTA_ACCEPTED: + # this check should be done here for new htlcs and ongoing on pending sets. + # Here it is done so that invalid received htlcs will never get added to a set, + # so the set still has a chance to succeed until mpp timeout. + _log_fail_reason(f"htlc.cltv_abs is unreasonably close: {htlc.cltv_abs=}, {local_height=}") raise exc_incorrect_or_unknown_pd - # detect callback - # if there is a trampoline_onion, maybe_fulfill_htlc will be called again - # order is important: if we receive a trampoline onion for a hold invoice, we need to peel the onion first. - + # extract trampoline if processed_onion.trampoline_onion_packet: - # TODO: we should check that all trampoline_onions are the same - trampoline_onion = self.process_onion_packet( + trampoline_onion = self._process_incoming_onion_packet( processed_onion.trampoline_onion_packet, payment_hash=payment_hash, - onion_packet_bytes=onion_packet_bytes, is_trampoline=True) - if trampoline_onion.are_we_final: - # trampoline- we are final recipient of HTLC - # note: the returned payment_key will contain the inner payment_secret - return self.maybe_fulfill_htlc( - chan=chan, - htlc=htlc, - processed_onion=trampoline_onion, - outer_onion_payment_secret=payment_secret_from_onion, - onion_packet_bytes=onion_packet_bytes, - already_forwarded=already_forwarded, - ) - else: - callback = lambda: self.lnworker.maybe_forward_trampoline( - payment_hash=payment_hash, - inc_cltv_abs=htlc.cltv_abs, # TODO: use max or enforce same value across mpp parts - outer_onion=processed_onion, - trampoline_onion=trampoline_onion, - fw_payment_key=payment_key) - return None, (payment_key, callback) - # TODO don't accept payments twice for same invoice - # note: we don't check invoice expiry (bolt11 'x' field) on the receiver-side. - # - semantics are weird: would make sense for simple-payment-receives, but not - # if htlc is expected to be pending for a while, e.g. for a hold-invoice. + # compare trampoline onion against outer onion according to: + # https://github.com/lightning/bolts/blob/9938ab3d6160a3ba91f3b0e132858ab14bfe4f81/04-onion-routing.md?plain=1#L547-L553 + if trampoline_onion.are_we_final: + try: + assert not processed_onion.outgoing_cltv_value < trampoline_onion.outgoing_cltv_value + is_mpp = processed_onion.total_msat > processed_onion.amt_to_forward + if is_mpp: + assert not processed_onion.total_msat < trampoline_onion.amt_to_forward + else: + assert not processed_onion.amt_to_forward < trampoline_onion.amt_to_forward + except AssertionError: + _log_fail_reason(f'incorrect trampoline onion {processed_onion=}\n{trampoline_onion=}') + raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00') + + return self._check_unfulfilled_htlc( + chan=chan, + htlc=htlc, + processed_onion=trampoline_onion, + outer_onion_payment_secret=payment_secret_from_onion, + ) + info = self.lnworker.get_payment_info(payment_hash) if info is None: - log_fail_reason(f"no payment_info found for RHASH {htlc.payment_hash.hex()}") + _log_fail_reason(f"no payment_info found for RHASH {payment_hash.hex()}") + raise exc_incorrect_or_unknown_pd + elif info.status == PR_PAID: + _log_fail_reason(f"invoice already paid: {payment_hash.hex()=}") + raise exc_incorrect_or_unknown_pd + elif blocks_to_expiry < info.min_final_cltv_delta: + _log_fail_reason( + f"min final cltv delta lower than requested: " + f"{payment_hash.hex()=} {htlc.cltv_abs=} {blocks_to_expiry=}" + ) + raise exc_incorrect_or_unknown_pd + elif htlc.timestamp > info.expiration_ts: # the set will get failed too if now > exp_ts + _log_fail_reason(f"not accepting htlc for expired invoice") raise exc_incorrect_or_unknown_pd - preimage = self.lnworker.get_preimage(payment_hash) - expected_payment_secret = self.lnworker.get_payment_secret(htlc.payment_hash) - if payment_secret_from_onion != expected_payment_secret: - log_fail_reason(f'incorrect payment secret {payment_secret_from_onion.hex()} != {expected_payment_secret.hex()}') + expected_payment_secret = self.lnworker.get_payment_secret(payment_hash) + if not util.constant_time_compare(payment_secret_from_onion, expected_payment_secret): + _log_fail_reason(f'incorrect payment secret: {payment_secret_from_onion.hex()=}') raise exc_incorrect_or_unknown_pd + invoice_msat = info.amount_msat if channel_opening_fee: + # deduct just-in-time channel fees from invoice amount invoice_msat -= channel_opening_fee if not (invoice_msat is None or invoice_msat <= total_msat <= 2 * invoice_msat): - log_fail_reason(f"total_msat={total_msat} too different from invoice_msat={invoice_msat}") + _log_fail_reason(f"{total_msat=} too different from {invoice_msat=}") raise exc_incorrect_or_unknown_pd - hold_invoice_callback = self.lnworker.hold_invoice_callbacks.get(payment_hash) - if hold_invoice_callback and not preimage: - callback = lambda: hold_invoice_callback(payment_hash) - return None, (payment_key, callback) + return payment_key - if payment_hash.hex() in self.lnworker.dont_settle_htlcs: - return None, None + def _fulfill_htlc_set(self, payment_key: str, preimage: bytes): + htlc_set = self.lnworker.received_mpp_htlcs[payment_key] + assert len(htlc_set.htlcs) > 0, f"{htlc_set=}" + assert htlc_set.resolution == RecvMPPResolution.SETTLING + assert htlc_set.parent_set_key is None, f"Must not settle child {htlc_set=}" + # get payment hash of any htlc in the set (they are all the same) + payment_hash = htlc_set.get_payment_hash() + assert payment_hash is not None, htlc_set + for mpp_htlc in list(htlc_set.htlcs): + htlc_id = mpp_htlc.htlc.htlc_id + chan = self.lnworker.get_channel_by_short_id(mpp_htlc.scid) + if chan.channel_id not in self.channels: + # this htlc belongs to another peer and has to be settled in their htlc_switch + continue + if not chan.can_update_ctx(proposer=LOCAL): + continue + self.logger.info(f"fulfill htlc: {chan.short_channel_id}. {htlc_id=}. {payment_hash.hex()=}") + if chan.hm.was_htlc_preimage_released(htlc_id=htlc_id, htlc_proposer=REMOTE): + # this check is intended to gracefully handle stale htlcs in the set, e.g. after a crash + self.logger.debug(f"{mpp_htlc=} was already settled before, dropping it.") + htlc_set.htlcs.remove(mpp_htlc) + continue + self._fulfill_htlc(chan, htlc_id, preimage) + htlc_set.htlcs.remove(mpp_htlc) + # reset just-in-time opening fee of channel + chan.opening_fee = None - if not preimage: - if not already_forwarded: - log_fail_reason(f"missing preimage and no hold invoice callback {payment_hash.hex()}") - raise exc_incorrect_or_unknown_pd - else: - return None, None - - chan.opening_fee = None - self.logger.info(f"maybe_fulfill_htlc. will FULFILL HTLC: chan {chan.short_channel_id}. htlc={str(htlc)}") - return preimage, None - - def fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes): - self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") - assert chan.can_update_ctx(proposer=LOCAL), f"cannot send updates: {chan.short_channel_id}" + def _fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes): assert chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id) self.received_htlcs_pending_removal.add((chan, htlc_id)) chan.settle_htlc(preimage, htlc_id) @@ -2316,6 +2290,61 @@ class Peer(Logger, EventListener): id=htlc_id, payment_preimage=preimage) + def _fail_htlc_set( + self, + payment_key: str, + error_tuple: Tuple[Optional[bytes], Optional[OnionFailureCode | int], Optional[bytes]], + ): + htlc_set = self.lnworker.received_mpp_htlcs[payment_key] + assert htlc_set.resolution in (RecvMPPResolution.FAILED, RecvMPPResolution.EXPIRED) + + raw_error, error_code, error_data = error_tuple + local_height = self.network.blockchain().height() + for mpp_htlc in list(htlc_set.htlcs): + chan = self.lnworker.get_channel_by_short_id(mpp_htlc.scid) + htlc_id = mpp_htlc.htlc.htlc_id + if chan.channel_id not in self.channels: + # this htlc belongs to another peer and has to be settled in their htlc_switch + continue + if not chan.can_update_ctx(proposer=LOCAL): + continue + assert chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id) + if chan.hm.was_htlc_failed(htlc_id=htlc_id, htlc_proposer=REMOTE): + # this check is intended to gracefully handle stale htlcs in the set, e.g. after a crash + self.logger.debug(f"{mpp_htlc=} was already failed before, dropping it.") + htlc_set.htlcs.remove(mpp_htlc) + continue + onion_packet = self._parse_onion_packet(mpp_htlc.unprocessed_onion) + processed_onion_packet = self._process_incoming_onion_packet( + onion_packet, + payment_hash=mpp_htlc.htlc.payment_hash, + is_trampoline=False, + ) + if raw_error: + error_bytes = obfuscate_onion_error(raw_error, onion_packet.public_key, self.privkey) + else: + assert isinstance(error_code, (OnionFailureCode, int)) + if error_code == OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS: + amount_to_forward = processed_onion_packet.amt_to_forward + # if this was a trampoline htlc we use the inner amount_to_forward as this is + # the value known by the sender + if processed_onion_packet.trampoline_onion_packet: + processed_trampoline_onion_packet = self._process_incoming_onion_packet( + processed_onion_packet.trampoline_onion_packet, + payment_hash=mpp_htlc.htlc.payment_hash, + is_trampoline=True, + ) + amount_to_forward = processed_trampoline_onion_packet.amt_to_forward + error_data = amount_to_forward.to_bytes(8, byteorder="big") + e = OnionRoutingFailure(code=error_code, data=error_data or b'') + error_bytes = e.to_wire_msg(onion_packet, self.privkey, local_height) + self.fail_htlc( + chan=chan, + htlc_id=htlc_id, + error_bytes=error_bytes, + ) + htlc_set.htlcs.remove(mpp_htlc) + def fail_htlc(self, *, chan: Channel, htlc_id: int, error_bytes: bytes): self.logger.info(f"fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.") assert chan.can_update_ctx(proposer=LOCAL), f"cannot send updates: {chan.short_channel_id}" @@ -2328,7 +2357,7 @@ class Peer(Logger, EventListener): len=len(error_bytes), reason=error_bytes) - def fail_malformed_htlc(self, *, chan: Channel, htlc_id: int, reason: OnionRoutingFailure): + def fail_malformed_htlc(self, *, chan: Channel, htlc_id: int, reason: OnionParsingError): self.logger.info(f"fail_malformed_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.") assert chan.can_update_ctx(proposer=LOCAL), f"cannot send updates: {chan.short_channel_id}" if not (reason.code & OnionFailureCodeMetaFlag.BADONION and len(reason.data) == 32): @@ -2764,79 +2793,107 @@ class Peer(Logger, EventListener): @util.profiler(min_threshold=0.02) def _run_htlc_switch_iteration(self): self._maybe_cleanup_received_htlcs_pending_removal() - # In this loop, an item of chan.unfulfilled_htlcs may go through 4 stages: - # - 1. not forwarded yet: (None, onion_packet_hex) - # - 2. forwarded: (forwarding_key, onion_packet_hex) - # - 3. processed: (forwarding_key, None), not irrevocably removed yet - # - 4. done: (forwarding_key, None), irrevocably removed + # htlc processing happens in two steps: + # 1. Step: Iterating through all channels and their pending htlcs, doing validation + # feasible for single htlcs (some checks only make sense on the whole mpp set) and + # then collecting these htlcs in a mpp set by payment key. + # HTLCs failing these checks will get failed directly and won't be added to any set. + # No htlcs will get settled in this step, settling only happens on complete mpp sets. + # If a new htlc belongs to a set which has already been failed, the htlc will be failed + # and not added to any set. + # Each htlc is only supposed to go through this first loop once when being received. for chan_id, chan in self.channels.items(): if not chan.can_update_ctx(proposer=LOCAL): continue self.maybe_send_commitment(chan) - done = set() unfulfilled = chan.unfulfilled_htlcs - for htlc_id, (onion_packet_hex, forwarding_key) in unfulfilled.items(): + for htlc_id, onion_packet_hex in list(unfulfilled.items()): if not chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id): continue + htlc = chan.hm.get_htlc_by_id(REMOTE, htlc_id) - if chan.hm.is_htlc_irrevocably_removed_yet(htlc_proposer=REMOTE, htlc_id=htlc_id): - assert onion_packet_hex is None - self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc) - if forwarding_key: - self.lnworker.maybe_cleanup_forwarding(forwarding_key) - done.add(htlc_id) - continue - if onion_packet_hex is None: - # has been processed already - continue - error_reason = None # type: Optional[OnionRoutingFailure] - error_bytes = None # type: Optional[bytes] - preimage = None - onion_packet_bytes = bytes.fromhex(onion_packet_hex) - onion_packet = None try: - onion_packet = OnionPacket.from_bytes(onion_packet_bytes) - except OnionRoutingFailure as e: - error_reason = e - else: - try: - preimage, _forwarding_key, error_bytes = self.process_unfulfilled_htlc( - chan=chan, - htlc=htlc, - forwarding_key=forwarding_key, - onion_packet_bytes=onion_packet_bytes, - onion_packet=onion_packet) - if _forwarding_key: - assert forwarding_key is None - unfulfilled[htlc_id] = onion_packet_hex, _forwarding_key - except OnionRoutingFailure as e: - error_bytes = construct_onion_error(e, onion_packet.public_key, self.privkey, self.network.get_local_height()) - if error_bytes: - error_bytes = obfuscate_onion_error(error_bytes, onion_packet.public_key, our_onion_private_key=self.privkey) + onion_packet = self._parse_onion_packet(onion_packet_hex) + except OnionParsingError as e: + self.fail_malformed_htlc( + chan=chan, + htlc_id=htlc.htlc_id, + reason=e, + ) + del unfulfilled[htlc_id] + continue - if preimage or error_reason or error_bytes: - if preimage: - self.lnworker.set_request_status(htlc.payment_hash, PR_PAID) - if not self.lnworker.enable_htlc_settle: - continue - self.fulfill_htlc(chan, htlc.htlc_id, preimage) - elif error_bytes: - self.fail_htlc( - chan=chan, - htlc_id=htlc.htlc_id, - error_bytes=error_bytes) + try: + processed_onion_packet = self._process_incoming_onion_packet( + onion_packet, + payment_hash=htlc.payment_hash, + is_trampoline=False, + ) + payment_key: str = self._check_unfulfilled_htlc( + chan=chan, + htlc=htlc, + processed_onion=processed_onion_packet, + ) + self.lnworker.update_or_create_mpp_with_received_htlc( + payment_key=payment_key, + scid=chan.short_channel_id, + htlc=htlc, + unprocessed_onion_packet=onion_packet_hex, # outer onion if trampoline + ) + except OnionParsingError as e: # could be raised when parsing the inner trampoline onion + self.fail_malformed_htlc( + chan=chan, + htlc_id=htlc.htlc_id, + reason=e, + ) + except Exception as e: + # Fail the htlc directly if it fails to pass these tests, it will not get added to a htlc set. + # https://github.com/lightning/bolts/blob/14272b1bd9361750cfdb3e5d35740889a6b510b5/04-onion-routing.md?plain=1#L388 + reraise = False + if isinstance(e, OnionRoutingFailure): + orf = e else: - self.fail_malformed_htlc( - chan=chan, - htlc_id=htlc.htlc_id, - reason=error_reason) - # blank onion field to mark it as processed - unfulfilled[htlc_id] = None, forwarding_key + orf = OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') + reraise = True # propagate this out, as this might suggest a bug + error_bytes = orf.to_wire_msg(onion_packet, self.privkey, self.network.get_local_height()) + self.fail_htlc( + chan=chan, + htlc_id=htlc.htlc_id, + error_bytes=error_bytes, + ) + if reraise: + raise + finally: + del unfulfilled[htlc_id] - # cleanup - for htlc_id in done: - unfulfilled.pop(htlc_id) - self.maybe_send_commitment(chan) + # 2. Step: Acting on sets of htlcs. + # Doing further checks that have to be done on sets of htlcs (e.g. total amount checks) + # and checks that have to be done continuously like checking for timeout. + # A set marked as failed once must never settle any htlcs associated to it. + # The sets are shared between all peers, so each peers htlc_switch acts on the same sets. + for payment_key, htlc_set in list(self.lnworker.received_mpp_htlcs.items()): + any_error, preimage, callback = self._check_unfulfilled_htlc_set(payment_key, htlc_set) + assert bool(any_error) + bool(preimage) + bool(callback) <= 1, \ + f"{any_error=}, {bool(preimage)=}, {callback=}" + if any_error: + error_tuple = self.lnworker.set_htlc_set_error(payment_key, any_error) + self._fail_htlc_set(payment_key, error_tuple) + if preimage: + if self.lnworker.enable_htlc_settle: + self.lnworker.set_request_status(htlc_set.get_payment_hash(), PR_PAID) + self._fulfill_htlc_set(payment_key, preimage) + if callback: + task = asyncio.create_task(callback()) + task.add_done_callback( # log exceptions occurring in callback + lambda t, pk=payment_key: self.logger.exception( + f"cb failed: " + f"{self.lnworker.received_mpp_htlcs[pk]=}", exc_info=t.exception()) if t.exception() else None + ) + + if len(self.lnworker.received_mpp_htlcs[payment_key].htlcs) == 0: + self.logger.debug(f"deleting resolved mpp set: {payment_key=}") + del self.lnworker.received_mpp_htlcs[payment_key] + self.lnworker.maybe_cleanup_forwarding(payment_key) def _maybe_cleanup_received_htlcs_pending_removal(self) -> None: done = set() @@ -2861,107 +2918,332 @@ class Peer(Logger, EventListener): await group.spawn(htlc_switch_iteration()) await group.spawn(self.got_disconnected.wait()) - def process_unfulfilled_htlc( - self, *, - chan: Channel, - htlc: UpdateAddHtlc, - forwarding_key: Optional[str], - onion_packet_bytes: bytes, - onion_packet: OnionPacket) -> Tuple[Optional[bytes], Optional[str], Optional[bytes]]: - """ - return (preimage, payment_key, error_bytes) with at most a single element that is not None - raise an OnionRoutingFailure if we need to fail the htlc - """ - payment_hash = htlc.payment_hash - processed_onion = self.process_onion_packet( - onion_packet, - payment_hash=payment_hash, - onion_packet_bytes=onion_packet_bytes) + def _log_htlc_fail_reason_cb( + self, + scid: ShortChannelID, + htlc: UpdateAddHtlc, + onion_payload: dict + ) -> Callable[[str], None]: + def _log_fail_reason(reason: str) -> None: + self.logger.info(f"will FAIL HTLC: {str(scid)=}. {reason=}. {str(htlc)=}. {onion_payload=}") + return _log_fail_reason - preimage, forwarding_info = self.maybe_fulfill_htlc( - chan=chan, - htlc=htlc, - processed_onion=processed_onion, - onion_packet_bytes=onion_packet_bytes, - already_forwarded=bool(forwarding_key)) + def _log_htlc_set_fail_reason_cb(self, mpp_set: ReceivedMPPStatus) -> Callable[[str], None]: + def log_fail_reason(reason: str): + for mpp_htlc in mpp_set.htlcs: + try: + processed_onion = self._process_incoming_onion_packet( + onion_packet=self._parse_onion_packet(mpp_htlc.unprocessed_onion), + payment_hash=mpp_htlc.htlc.payment_hash, + is_trampoline=False, + ) + onion_payload = processed_onion.hop_data.payload + except Exception: + onion_payload = {} - if not forwarding_key: - if forwarding_info: - # HTLC we are supposed to forward, but haven't forwarded yet - payment_key, forwarding_callback = forwarding_info + self._log_htlc_fail_reason_cb( + mpp_htlc.scid, + mpp_htlc.htlc, + onion_payload, + )(f"mpp set {id(mpp_set)} failed: {reason}") + + return log_fail_reason + + def _check_unfulfilled_htlc_set( + self, + payment_key: str, + mpp_set: ReceivedMPPStatus + ) -> Tuple[ + Optional[Union[OnionRoutingFailure, OnionFailureCode, bytes]], # error types used to fail the set + Optional[bytes], # preimage to settle the set + Optional[Callable[[], Awaitable[None]]], # callback + ]: + """ + Returns what to do next with the given set of htlcs: + * Fail whole set -> returns error code + * Settle whole set -> Returns preimage + * call callback (e.g. forwarding, hold invoice) + May modify the mpp set in lnworker.received_mpp_htlcs (e.g. by setting its resolution to COMPLETE). + """ + _log_fail_reason = self._log_htlc_set_fail_reason_cb(mpp_set) + + if (final_state := self._check_final_mpp_set_state(payment_key, mpp_set)) is not None: + return final_state + + assert mpp_set.resolution in (RecvMPPResolution.WAITING, RecvMPPResolution.COMPLETE) + chain = self.network.blockchain() + local_height = chain.height() + if chain.is_tip_stale(): + _log_fail_reason(f"our chain tip is stale: {local_height=}") + return OnionFailureCode.TEMPORARY_NODE_FAILURE, None, None + + amount_msat: int = 0 # sum(amount_msat of each htlc) + total_msat = None # type: Optional[int] + payment_hash = mpp_set.get_payment_hash() + closest_cltv_abs = mpp_set.get_closest_cltv_abs() + first_htlc_timestamp = mpp_set.get_first_htlc_timestamp() + processed_onions = {} # type: dict[ReceivedMPPHtlc, Tuple[ProcessedOnionPacket, Optional[ProcessedOnionPacket]]] + for mpp_htlc in mpp_set.htlcs: + processed_onion = self._process_incoming_onion_packet( + onion_packet=self._parse_onion_packet(mpp_htlc.unprocessed_onion), + payment_hash=payment_hash, + is_trampoline=False, # this is always the outer onion + ) + processed_onions[mpp_htlc] = (processed_onion, None) + inner_onion = None + if processed_onion.trampoline_onion_packet: + inner_onion = self._process_incoming_onion_packet( + onion_packet=processed_onion.trampoline_onion_packet, + payment_hash=payment_hash, + is_trampoline=True, + ) + processed_onions[mpp_htlc] = (processed_onion, inner_onion) + + total_msat_outer_onion = processed_onion.total_msat + total_msat_inner_onion = inner_onion.total_msat if inner_onion else None + if total_msat is None: + total_msat = total_msat_inner_onion or total_msat_outer_onion + + # check total_msat is equal for all htlcs of the set + if total_msat != (total_msat_inner_onion or total_msat_outer_onion): + _log_fail_reason(f"total_msat is not uniform: {total_msat=} != {processed_onion.total_msat=}") + return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None + + amount_msat += mpp_htlc.htlc.amount_msat + + # If the set contains outer onions with different payment secrets, the set's payment_key is + # derived from the trampoline/invoice/inner payment secret, so it is the second stage of a + # multi-trampoline payment in which all the trampoline parts/htlcs got combined. + # In this case the amt_to_forward cannot be compared as it may differ between the trampoline parts. + # However, amt_to_forward should be similar for all onions of a single trampoline part and gets + # compared in the first stage where the htlc set represents a single trampoline part. + outer_onions = [onions[0] for onions in processed_onions.values()] + can_have_different_amt_to_fwd = not all(o.payment_secret == outer_onions[0].payment_secret for o in outer_onions) + trampoline_onions = iter(onions[1] for onions in processed_onions.values()) + if not lnonion.compare_trampoline_onions(trampoline_onions, exclude_amt_to_fwd=can_have_different_amt_to_fwd): + _log_fail_reason(f"got inconsistent {trampoline_onions=}") + return OnionFailureCode.INVALID_ONION_PAYLOAD, None, None + + if len(processed_onions) == 1: + outer_onion, inner_onion = next(iter(processed_onions.values())) + if not outer_onion.are_we_final: + assert inner_onion is None, f"{outer_onion=}\n{inner_onion=}" if not self.lnworker.enable_htlc_forwarding: return None, None, None - if payment_key not in self.lnworker.active_forwardings: - async def wrapped_callback(): - forwarding_coro = forwarding_callback() - try: - next_htlc = await forwarding_coro - if next_htlc: - htlc_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc.htlc_id) - self.lnworker.active_forwardings[payment_key].append(next_htlc) - self.lnworker.downstream_to_upstream_htlc[next_htlc] = htlc_key - except OnionRoutingFailure as e: - if len(self.lnworker.active_forwardings[payment_key]) == 0: - self.lnworker.save_forwarding_failure(payment_key, failure_message=e) - # TODO what about other errors? e.g. TxBroadcastError for a swap. - # - malicious electrum server could fake TxBroadcastError - # Could we "catch-all Exception" and fail back the htlcs with e.g. TEMPORARY_NODE_FAILURE? - # - we don't want to fail the inc-HTLC for a syntax error that happens in the callback - # If we don't call save_forwarding_failure(), the inc-HTLC gets stuck until expiry - # and then the inc-channel will get force-closed. - # => forwarding_callback() could have an API with two exceptions types: - # - type1, such as OnionRoutingFailure, that signals we need to fail back the inc-HTLC - # - type2, such as TxBroadcastError, that signals we want to retry the callback - # add to list - assert len(self.lnworker.active_forwardings.get(payment_key, [])) == 0 - self.lnworker.active_forwardings[payment_key] = [] - fut = asyncio.ensure_future(wrapped_callback()) - # return payment_key so this branch will not be executed again - return None, payment_key, None - elif preimage: - return preimage, None, None - else: - # we are waiting for mpp consolidation or preimage + # this is a single (non-trampoline) htlc set which needs to be forwarded. + # set to settling state so it will not be failed or forwarded twice. + self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.SETTLING) + fwd_cb = lambda: self.lnworker.maybe_forward_htlc_set(payment_key, processed_htlc_set=processed_onions) + return None, None, fwd_cb + + assert payment_hash is not None and total_msat is not None + # check for expiry over time and potentially fail the whole set if any + # htlc's cltv becomes too close + blocks_to_expiry = max(0, closest_cltv_abs - local_height) + if blocks_to_expiry < MIN_FINAL_CLTV_DELTA_ACCEPTED: + _log_fail_reason(f"htlc.cltv_abs is unreasonably close") + return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None + + # check for mpp expiry (if incomplete and expired -> fail) + if mpp_set.resolution == RecvMPPResolution.WAITING \ + or not self.lnworker.is_payment_bundle_complete(payment_key): + # maybe this set is COMPLETE but the bundle is not yet completed, so the bundle can be considered WAITING + if int(time.time()) - first_htlc_timestamp > self.lnworker.MPP_EXPIRY \ + or self.lnworker.stopping_soon: + _log_fail_reason(f"MPP TIMEOUT (> {self.lnworker.MPP_EXPIRY} sec)") + return OnionFailureCode.MPP_TIMEOUT, None, None + + if mpp_set.resolution == RecvMPPResolution.WAITING: + # check if set is first stage multi-trampoline payment to us + # first stage trampoline payment: + # is a trampoline payment + we_are_final + payment key is derived from outer onion's payment secret + # (so it is not the payment secret we requested in the invoice, but some secret set by a + # trampoline forwarding node on the route). + # if it is first stage, check if sum(htlcs) >= amount_to_forward of the trampoline_payload. + # If this part is complete, move the htlcs to the overall mpp set of the payment (keyed by inner secret). + # Once the second stage set (the set containing all htlcs of the separate trampoline parts) + # is complete, the payment gets fulfilled. + trampoline_payment_key = None + any_trampoline_onion = next(iter(processed_onions.values()))[1] + if any_trampoline_onion and any_trampoline_onion.are_we_final: + trampoline_payment_secret = any_trampoline_onion.payment_secret + assert trampoline_payment_secret == self.lnworker.get_payment_secret(payment_hash) + trampoline_payment_key = (payment_hash + trampoline_payment_secret).hex() + + if trampoline_payment_key and trampoline_payment_key != payment_key: + # first stage of trampoline payment, the first stage must never get set COMPLETE + if amount_msat >= any_trampoline_onion.amt_to_forward: + # setting the parent key will mark the htlcs to be moved to the parent set + self.logger.debug(f"trampoline part complete. {len(mpp_set.htlcs)=}, " + f"{amount_msat=}. setting parent key: {trampoline_payment_key}") + self.lnworker.received_mpp_htlcs[payment_key] = mpp_set._replace( + parent_set_key=trampoline_payment_key, + ) + elif amount_msat >= total_msat: + # set mpp_set as completed as we have received the full total_msat + mpp_set = self.lnworker.set_mpp_resolution( + payment_key=payment_key, + new_resolution=RecvMPPResolution.COMPLETE, + ) + + # check if this set is a trampoline forwarding and potentially return forwarding callback + # note: all inner trampoline onions are equal (enforced above) + _, any_inner_onion = next(iter(processed_onions.values())) + if any_inner_onion and not any_inner_onion.are_we_final: + # this is a trampoline forwarding + can_forward = mpp_set.resolution == RecvMPPResolution.COMPLETE and self.lnworker.enable_htlc_forwarding + if not can_forward: return None, None, None - else: - # HTLC we are supposed to forward, and have already forwarded - # for final trampoline onions, forwarding failures are stored with forwarding_key (which is the inner key) - payment_key = forwarding_key - preimage = self.lnworker.get_preimage(payment_hash) - error_bytes, error_reason = self.lnworker.get_forwarding_failure(payment_key) - if error_bytes: - return None, None, error_bytes - if error_reason: - raise error_reason - if preimage: - return preimage, None, None + self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.SETTLING) + fwd_cb = lambda: self.lnworker.maybe_forward_htlc_set(payment_key, processed_htlc_set=processed_onions) + return None, None, fwd_cb + + # -- from here on it's assumed this set is a payment for us (not something to forward) -- + payment_info = self.lnworker.get_payment_info(payment_hash) + if payment_info is None: + _log_fail_reason(f"payment info has been deleted") + return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None + + # check invoice expiry, fail set if the invoice has expired before it was completed + if mpp_set.resolution == RecvMPPResolution.WAITING: + if int(time.time()) > payment_info.expiration_ts: + _log_fail_reason(f"invoice is expired {payment_info.expiration_ts=}") + return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None return None, None, None - def process_onion_packet( + if payment_hash.hex() in self.lnworker.dont_settle_htlcs: + # used by hold invoice cli to prevent the htlcs from getting fulfilled automatically + return None, None, None + + preimage = self.lnworker.get_preimage(payment_hash) + hold_invoice_callback = self.lnworker.hold_invoice_callbacks.get(payment_hash) + if not preimage and not hold_invoice_callback: + _log_fail_reason(f"cannot settle, no preimage or callback found for {payment_hash.hex()=}") + return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None + + if not self.lnworker.is_payment_bundle_complete(payment_key): + # don't allow settling before all sets of the bundle are COMPLETE + return None, None, None + else: + # If this set is part of a bundle now all parts are COMPLETE so the bundle can be deleted + # so the individual sets will get fulfilled. + self.lnworker.delete_payment_bundle(payment_key=bytes.fromhex(payment_key)) + + assert mpp_set.resolution == RecvMPPResolution.COMPLETE, "should return earlier if set is incomplete" + if not preimage: + assert hold_invoice_callback is not None, "should have been failed before" + async def callback(): + try: + await hold_invoice_callback(payment_hash) + except OnionRoutingFailure as e: # todo: should this catch all exceptions? + _log_fail_reason(f"hold invoice callback raised {e}") + self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.FAILED) + # mpp set must not be failed unless the consumer calls unregister_hold_invoice and + # callback must only be called once. This is enforced by setting the set to SETTLING. + self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.SETTLING) + return None, None, callback + + # settle htlc set + self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.SETTLING) + return None, preimage, None + + def _check_final_mpp_set_state( + self, + payment_key: str, + mpp_set: ReceivedMPPStatus, + ) -> Optional[Tuple[ + Optional[Union[OnionRoutingFailure, OnionFailureCode, bytes]], # error types used to fail the set + Optional[bytes], # preimage to settle the set + None, # callback + ]]: + """ + handle sets that are already in a state eligible for fulfillment or failure and shouldn't + go through another iteration of _check_unfulfilled_htlc_set. + """ + if len(mpp_set.htlcs) == 0: + # stale set, will get deleted on the next iteration + return None, None, None + + if mpp_set.resolution == RecvMPPResolution.FAILED: + error_bytes, failure_message = self.lnworker.get_forwarding_failure(payment_key) + if error_bytes or failure_message: + return error_bytes or failure_message, None, None + return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None + elif mpp_set.resolution == RecvMPPResolution.EXPIRED: + return OnionFailureCode.MPP_TIMEOUT, None, None + + if mpp_set.parent_set_key: + # this is a complete trampoline part of a multi trampoline payment. Move the htlcs to parent. + parent = self.lnworker.received_mpp_htlcs.get(mpp_set.parent_set_key) + if not parent: + parent = ReceivedMPPStatus( + resolution=RecvMPPResolution.WAITING, + htlcs=set(), + ) + self.lnworker.received_mpp_htlcs[mpp_set.parent_set_key] = parent + parent.htlcs.update(mpp_set.htlcs) + mpp_set.htlcs.clear() + return None, None, None # this set will get deleted as there are no htlcs in it anymore + + assert not mpp_set.parent_set_key + if mpp_set.resolution == RecvMPPResolution.SETTLING: + # this is an ongoing forwarding, or a set that has not yet been fully settled (and removed). + # note the htlcs in SETTLING will not get failed automatically, + # even if timeout comes close, so either a forwarding failure or preimage has to be set + error_bytes, failure_message = self.lnworker.get_forwarding_failure(payment_key) + if error_bytes or failure_message: + # this was a forwarding set and it failed + self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.FAILED) + return error_bytes or failure_message, None, None + preimage = self.lnworker.get_preimage(mpp_set.get_payment_hash()) + return None, preimage, None + + return None + + def _parse_onion_packet(self, onion_packet_hex: str) -> OnionPacket: + """ + https://github.com/lightning/bolts/blob/14272b1bd9361750cfdb3e5d35740889a6b510b5/02-peer-protocol.md?plain=1#L2352 + """ + onion_packet_bytes = None + try: + onion_packet_bytes = bytes.fromhex(onion_packet_hex) + onion_packet = OnionPacket.from_bytes(onion_packet_bytes) + except Exception as parsing_exc: + self.logger.warning(f"unable to parse onion: {str(parsing_exc)}") + onion_parsing_error = OnionParsingError( + code=OnionFailureCodeMetaFlag.BADONION, + data=sha256(onion_packet_bytes or b''), + ) + raise onion_parsing_error + return onion_packet + + def _process_incoming_onion_packet( self, onion_packet: OnionPacket, *, payment_hash: bytes, - onion_packet_bytes: bytes, is_trampoline: bool = False) -> ProcessedOnionPacket: - - failure_data = sha256(onion_packet_bytes) + onion_hash = onion_packet.onion_hash + cache_key = sha256(onion_hash + payment_hash + bytes([is_trampoline])) # type: ignore + if cached_onion := self._processed_onion_cache.get(cache_key): + return cached_onion try: - processed_onion = process_onion_packet( + processed_onion = lnonion.process_onion_packet( onion_packet, our_onion_private_key=self.privkey, associated_data=payment_hash, is_trampoline=is_trampoline) + self._processed_onion_cache[cache_key] = processed_onion except UnsupportedOnionPacketVersion: - raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_VERSION, data=failure_data) + raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_VERSION, data=onion_hash) except InvalidOnionPubkey: - raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_KEY, data=failure_data) + raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_KEY, data=onion_hash) except InvalidOnionMac: - raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_HMAC, data=failure_data) + raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_HMAC, data=onion_hash) except Exception as e: - self.logger.info(f"error processing onion packet: {e!r}") - raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_VERSION, data=failure_data) + self.logger.warning(f"error processing onion packet: {e!r}") + raise OnionParsingError(code=OnionFailureCodeMetaFlag.BADONION, data=onion_hash) if self.network.config.TEST_FAIL_HTLCS_AS_MALFORMED: - raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_VERSION, data=failure_data) + raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_VERSION, data=onion_hash) if self.network.config.TEST_FAIL_HTLCS_WITH_TEMP_NODE_FAILURE: raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') return processed_onion diff --git a/electrum/lnsweep.py b/electrum/lnsweep.py index 014c87ad4..9fededb8a 100644 --- a/electrum/lnsweep.py +++ b/electrum/lnsweep.py @@ -445,6 +445,8 @@ def sweep_our_ctx( if not preimage: # we might not have the preimage if this is a hold invoice continue + if htlc.payment_hash in chan.lnworker.dont_settle_htlcs: + continue else: preimage = None try: @@ -746,6 +748,8 @@ def sweep_their_ctx( if not preimage: # we might not have the preimage if this is a hold invoice continue + if htlc.payment_hash in chan.lnworker.dont_settle_htlcs: + continue else: preimage = None tx_htlc( diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 76a6c876b..28f1f4e36 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -1934,26 +1934,86 @@ class UpdateAddHtlc: # Note: these states are persisted in the wallet file. # Do not modify them without performing a wallet db upgrade +# todo: if this changes again states could also be persisted by name instead of int value as done for ChannelState class RecvMPPResolution(IntEnum): - WAITING = 0 - EXPIRED = 1 - COMPLETE = 2 - FAILED = 3 + WAITING = 0 # set is not complete yet, waiting for arrival of the remaining htlcs + EXPIRED = 1 # preimage must not be revealed + COMPLETE = 2 # set is complete but could still be failed (e.g. due to cltv timeout) + FAILED = 3 # preimage must not be revealed + SETTLING = 4 # Must not be failed, should be settled asap. + # Also used when forwarding (for upstream), in which case a downstream + # forwarding failure could still result in transitioning to FAILED. + + +r = RecvMPPResolution +allowed_mpp_set_transitions = ( + (r.WAITING, r.EXPIRED), + (r.WAITING, r.FAILED), + (r.WAITING, r.COMPLETE), + (r.WAITING, r.SETTLING), # normal htlc forwarding + + (r.COMPLETE, r.SETTLING), + (r.COMPLETE, r.FAILED), + (r.COMPLETE, r.EXPIRED), # this should only realistically happen for payment bundles + + (r.SETTLING, r.FAILED), # forwarding failure, hold invoice callback gets unregistered, and we don't have preimage + + (r.EXPIRED, r.FAILED), # doesn't seem useful but also not dangerous +) +del r + + +class ReceivedMPPHtlc(NamedTuple): + scid: ShortChannelID + htlc: UpdateAddHtlc + unprocessed_onion: str + + def __repr__(self): + return f"{self.scid}, {self.htlc=}, {self.unprocessed_onion[:15]=}..." + + @staticmethod + def from_tuple(scid, htlc, unprocessed_onion) -> 'ReceivedMPPHtlc': + assert is_hex_str(unprocessed_onion) and is_hex_str(scid) + return ReceivedMPPHtlc( + scid=ShortChannelID(bytes.fromhex(scid)), + htlc=UpdateAddHtlc.from_tuple(*htlc), + unprocessed_onion=unprocessed_onion, + ) class ReceivedMPPStatus(NamedTuple): resolution: RecvMPPResolution - expected_msat: int - htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]] + htlcs: set[ReceivedMPPHtlc] + # parent_set_key is needed as trampoline allows MPP to be nested, the parent_set_key is the + # payment key of the final mpp set (derived from inner trampoline onion payment secret) + # to which the separate trampoline sets htlcs get added once they are complete. + # https://github.com/lightning/bolts/pull/829/commits/bc7a1a0bc97b2293e7f43dd8a06529e5fdcf7cd2 + parent_set_key: str = None + + def get_first_htlc_timestamp(self) -> Optional[int]: + return min([mpp_htlc.htlc.timestamp for mpp_htlc in self.htlcs], default=None) + + def get_closest_cltv_abs(self) -> Optional[int]: + return min([mpp_htlc.htlc.cltv_abs for mpp_htlc in self.htlcs], default=None) + + def get_payment_hash(self) -> Optional[bytes]: + mpp_htlcs = iter(self.htlcs) + first_mpp_htlc = next(mpp_htlcs, None) + payment_hash = first_mpp_htlc.htlc.payment_hash if first_mpp_htlc else None + for mpp_htlc in mpp_htlcs: + assert mpp_htlc.htlc.payment_hash == payment_hash, "mpp set with inconsistent payment hashes" + return payment_hash @staticmethod @stored_in('received_mpp_htlcs', tuple) - def from_tuple(resolution, expected_msat, htlc_list) -> 'ReceivedMPPStatus': - htlc_set = set([(ShortChannelID(bytes.fromhex(scid)), UpdateAddHtlc.from_tuple(*x)) for (scid, x) in htlc_list]) + def from_tuple(resolution, htlc_list, parent_set_key=None) -> 'ReceivedMPPStatus': + assert isinstance(resolution, int) + htlc_set = set(ReceivedMPPHtlc.from_tuple(*htlc_data) for htlc_data in htlc_list) return ReceivedMPPStatus( resolution=RecvMPPResolution(resolution), - expected_msat=expected_msat, - htlc_set=htlc_set) + htlcs=htlc_set, + parent_set_key=parent_set_key, + ) class OnionFailureCodeMetaFlag(IntFlag): diff --git a/electrum/lnworker.py b/electrum/lnworker.py index d3cf02e0a..d4c5bc8cb 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -10,7 +10,7 @@ import time from enum import IntEnum from typing import ( Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, NamedTuple, Mapping, Any, Iterable, AsyncGenerator, - Callable, Awaitable + Callable, Awaitable, Union, ) from types import MappingProxyType import threading @@ -70,7 +70,7 @@ from .lnutil import ( ShortChannelID, HtlcLog, NoPathFound, InvalidGossipMsg, FeeBudgetExceeded, ImportedChannelBackupStorage, OnchainChannelBackupStorage, ln_compare_features, IncompatibleLightningFeatures, PaymentFeeBudget, NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE, GossipForwardingMessage, MIN_FUNDING_SAT, - MIN_FINAL_CLTV_DELTA_BUFFER_INVOICE, ReceivedMPPStatus, RecvMPPResolution, + MIN_FINAL_CLTV_DELTA_BUFFER_INVOICE, RecvMPPResolution, ReceivedMPPStatus, ReceivedMPPHtlc, PaymentSuccess, ) from .lnonion import ( @@ -1237,6 +1237,8 @@ class LNWallet(LNWorker): if type(chan) is Channel: self.save_channel(chan) self.clear_invoices_cache() + if chan._state == ChannelState.REDEEMED: + self.maybe_cleanup_mpp(chan) util.trigger_callback('channel', self.wallet, chan) def save_channel(self, chan: Channel): @@ -2399,15 +2401,47 @@ class LNWallet(LNWorker): self._payment_bundles_pkey_to_canon[pkey] = canon_pkey self._payment_bundles_canon_to_pkeylist[canon_pkey] = tuple(payment_keys) - def get_payment_bundle(self, payment_key: bytes) -> Sequence[bytes]: + def get_payment_bundle(self, payment_key: Union[bytes, str]) -> Sequence[bytes]: with self.lock: + if isinstance(payment_key, str): + try: + payment_key = bytes.fromhex(payment_key) + except ValueError: + # might be a forwarding payment_key which is not hex and will never have a bundle + return [] canon_pkey = self._payment_bundles_pkey_to_canon.get(payment_key) if canon_pkey is None: return [] return self._payment_bundles_canon_to_pkeylist[canon_pkey] - def delete_payment_bundle(self, payment_hash: bytes) -> None: - payment_key = self._get_payment_key(payment_hash) + def is_payment_bundle_complete(self, any_payment_key: str) -> bool: + """ + complete means a htlc set is available for each payment key of the payment bundle and + all htlc sets have a resolution >= COMPLETE (we got the whole payment bundle amount) + """ + # get all payment keys covered by this bundle + bundle_payment_keys = self.get_payment_bundle(any_payment_key) + if not bundle_payment_keys: # there is no payment bundle + return True + for payment_key in bundle_payment_keys: + mpp_set = self.received_mpp_htlcs.get(payment_key.hex()) + if mpp_set is None: + # payment bundle is missing htlc set for payment request + # it might have already been failed and deleted + return False + elif mpp_set.resolution not in (RecvMPPResolution.COMPLETE, RecvMPPResolution.SETTLING): + return False + return True + + def delete_payment_bundle( + self, *, + payment_hash: Optional[bytes] = None, + payment_key: Optional[bytes] = None, + ) -> None: + assert (payment_hash is not None) ^ (payment_key is not None), \ + "must provide exactly one of (payment_hash, payment_key)" + if not payment_key: + payment_key = self._get_payment_key(payment_hash) with self.lock: canon_pkey = self._payment_bundles_pkey_to_canon.get(payment_key) if canon_pkey is None: # is it ok for bundle to be missing?? @@ -2478,10 +2512,16 @@ class LNWallet(LNWorker): self.save_payment_info(info, write_to_disk=False) def register_hold_invoice(self, payment_hash: bytes, cb: Callable[[bytes], Awaitable[None]]): + assert self.get_preimage(payment_hash) is None, "hold invoice cb won't get called if preimage is already set" self.hold_invoice_callbacks[payment_hash] = cb def unregister_hold_invoice(self, payment_hash: bytes): - self.hold_invoice_callbacks.pop(payment_hash) + self.hold_invoice_callbacks.pop(payment_hash, None) + payment_key = self._get_payment_key(payment_hash).hex() + if payment_key in self.received_mpp_htlcs: + if self.get_preimage(payment_hash) is None: + # the pending mpp set can be failed as we don't have the preimage to settle it + self.set_mpp_resolution(payment_key, RecvMPPResolution.FAILED) def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None: assert info.status in SAVED_PR_STATUS @@ -2500,132 +2540,142 @@ class LNWallet(LNWorker): if write_to_disk: self.wallet.save_db() - def check_mpp_status( - self, *, - payment_secret: bytes, - short_channel_id: ShortChannelID, - htlc: UpdateAddHtlc, - expected_msat: int, - ) -> RecvMPPResolution: - """Returns the status of the incoming htlc set the given *htlc* belongs to. - - ACCEPTED simply means the mpp set is complete, and we can proceed with further - checks before fulfilling (or failing) the htlcs. - In particular, note that hold-invoice-htlcs typically remain in the ACCEPTED state - for quite some time -- not in the "WAITING" state (which would refer to the mpp set - not yet being complete!). - """ - payment_hash = htlc.payment_hash - payment_key = payment_hash + payment_secret - self.update_mpp_with_received_htlc( - payment_key=payment_key, scid=short_channel_id, htlc=htlc, expected_msat=expected_msat) - mpp_resolution = self.received_mpp_htlcs[payment_key.hex()].resolution - # if still waiting, calc resolution now: - if mpp_resolution == RecvMPPResolution.WAITING: - bundle = self.get_payment_bundle(payment_key) - if bundle: - payment_keys = bundle - else: - payment_keys = [payment_key] - first_timestamp = min([self.get_first_timestamp_of_mpp(pkey) for pkey in payment_keys]) - if self.get_payment_status(payment_hash) == PR_PAID: - mpp_resolution = RecvMPPResolution.COMPLETE - elif self.stopping_soon: - # try to time out pending HTLCs before shutting down - mpp_resolution = RecvMPPResolution.EXPIRED - elif all([self.is_mpp_amount_reached(pkey) for pkey in payment_keys]): - mpp_resolution = RecvMPPResolution.COMPLETE - elif time.time() - first_timestamp > self.MPP_EXPIRY: - mpp_resolution = RecvMPPResolution.EXPIRED - # save resolution, if any. - if mpp_resolution != RecvMPPResolution.WAITING: - for pkey in payment_keys: - if pkey.hex() in self.received_mpp_htlcs: - self.set_mpp_resolution(payment_key=pkey, resolution=mpp_resolution) - - return mpp_resolution - - def update_mpp_with_received_htlc( + def update_or_create_mpp_with_received_htlc( self, *, - payment_key: bytes, + payment_key: str, scid: ShortChannelID, htlc: UpdateAddHtlc, - expected_msat: int, + unprocessed_onion_packet: str, ): - # add new htlc to set - mpp_status = self.received_mpp_htlcs.get(payment_key.hex()) + # Payment key creation: + # * for regular forwarded htlcs -> "scid.hex() + ':%d' % htlc_id" [htlc key] + # * for trampoline forwarding -> "payment hash + payment secret from outer onion" + # * for final non-trampoline htlcs (we are receiver) -> "payment hash + payment secret from onion" + # * for final trampoline htlcs (we are receiver) -> 2. step grouping: + # 1. grouping of htlcs by "payments hash + outer onion payment secret", a 'multi-trampoline mpp part'. + # 2. once the set of step 1. is COMPLETE (amount_fwd outer onion >= total_amt outer onion) + # the htlcs get moved to the parent mpp set (created once first part is complete) grouped by: + # "payment_hash + inner onion payment secret (the one in the invoice)" + # After moving the htlcs the first set gets deleted. + # + # Add the validated htlc to the htlc set associated with the payment key. + # If no set exists, a new set in WAITING state is created. + mpp_status = self.received_mpp_htlcs.get(payment_key) if mpp_status is None: + self.logger.debug(f"creating new mpp set for {payment_key=}") mpp_status = ReceivedMPPStatus( resolution=RecvMPPResolution.WAITING, - expected_msat=expected_msat, - htlc_set=set(), + htlcs=set(), ) - if expected_msat != mpp_status.expected_msat: - self.logger.info( - f"marking received mpp as failed. inconsistent total_msats in bucket. {payment_key.hex()=}") - mpp_status = mpp_status._replace(resolution=RecvMPPResolution.FAILED) - key = (scid, htlc) - if key not in mpp_status.htlc_set: - mpp_status.htlc_set.add(key) # side-effecting htlc_set - self.received_mpp_htlcs[payment_key.hex()] = mpp_status - def set_mpp_resolution(self, *, payment_key: bytes, resolution: RecvMPPResolution): - mpp_status = self.received_mpp_htlcs[payment_key.hex()] - self.logger.info(f'set_mpp_resolution {resolution.name} {len(mpp_status.htlc_set)} {payment_key.hex()}') - self.received_mpp_htlcs[payment_key.hex()] = mpp_status._replace(resolution=resolution) + if mpp_status.resolution > RecvMPPResolution.WAITING: + # we are getting a htlc for a set that is not in WAITING state, it cannot be safely added + self.logger.info(f"htlc set cannot accept htlc, failing htlc: {scid=} {htlc.htlc_id=}") + if mpp_status == RecvMPPResolution.EXPIRED: + raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'') + raise OnionRoutingFailure( + code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, + data=htlc.amount_msat.to_bytes(8, byteorder="big"), + ) - def is_mpp_amount_reached(self, payment_key: bytes) -> bool: - amounts = self.get_mpp_amounts(payment_key) - if amounts is None: - return False - total, expected = amounts - return total >= expected + new_htlc = ReceivedMPPHtlc( + scid=scid, + htlc=htlc, + unprocessed_onion=unprocessed_onion_packet, + ) + assert new_htlc not in mpp_status.htlcs, "each htlc should make it here only once?" + assert isinstance(unprocessed_onion_packet, str) + mpp_status.htlcs.add(new_htlc) # side-effecting htlc_set + self.received_mpp_htlcs[payment_key] = mpp_status - def is_complete_mpp(self, payment_hash: bytes) -> bool: + def set_mpp_resolution(self, payment_key: str, new_resolution: RecvMPPResolution) -> ReceivedMPPStatus: + mpp_status = self.received_mpp_htlcs[payment_key] + if mpp_status.resolution == new_resolution: + return mpp_status + if not (mpp_status.resolution, new_resolution) in lnutil.allowed_mpp_set_transitions: + raise ValueError(f'forbidden mpp set transition: {mpp_status.resolution} -> {new_resolution}') + self.logger.info(f'set_mpp_resolution {new_resolution.name} {len(mpp_status.htlcs)=}: {payment_key=}') + self.received_mpp_htlcs[payment_key] = mpp_status._replace(resolution=new_resolution) + self.wallet.save_db() + return self.received_mpp_htlcs[payment_key] + + def set_htlc_set_error( + self, + payment_key: str, + error: Union[bytes, OnionFailureCode, OnionRoutingFailure], + ) -> Optional[Tuple[Optional[bytes], Optional[OnionFailureCode | int], Optional[bytes]]]: + """ + handles different types of errors and sets the htlc set to failed, then returns a more + structured tuple of error types which can then be used to fail the htlc set + """ + htlc_set = self.received_mpp_htlcs[payment_key] + assert htlc_set.resolution != RecvMPPResolution.SETTLING + raw_error, error_code, error_data = None, None, None + if isinstance(error, bytes): + raw_error = error + elif isinstance(error, OnionFailureCode): + error_code = error + elif isinstance(error, OnionRoutingFailure): + error_code, error_data = OnionFailureCode.from_int(error.code), error.data + else: + raise ValueError(f"invalid error type: {repr(error)}") + + if error_code == OnionFailureCode.MPP_TIMEOUT: + self.set_mpp_resolution(payment_key=payment_key, new_resolution=RecvMPPResolution.EXPIRED) + else: + self.set_mpp_resolution(payment_key=payment_key, new_resolution=RecvMPPResolution.FAILED) + + return raw_error, error_code, error_data + + def get_mpp_resolution(self, payment_hash: bytes) -> Optional[RecvMPPResolution]: payment_key = self._get_payment_key(payment_hash) status = self.received_mpp_htlcs.get(payment_key.hex()) - return status and status.resolution == RecvMPPResolution.COMPLETE + return status.resolution if status else None + + def is_complete_mpp(self, payment_hash: bytes) -> bool: + resolution = self.get_mpp_resolution(payment_hash) + if resolution is not None: + return resolution in (RecvMPPResolution.COMPLETE, RecvMPPResolution.SETTLING) + return False def get_payment_mpp_amount_msat(self, payment_hash: bytes) -> Optional[int]: """Returns the received mpp amount for given payment hash.""" payment_key = self._get_payment_key(payment_hash) - amounts = self.get_mpp_amounts(payment_key) - if not amounts: + total_msat = self.get_mpp_amounts(payment_key) + if not total_msat: return None - total_msat, _ = amounts return total_msat - def get_mpp_amounts(self, payment_key: bytes) -> Optional[Tuple[int, int]]: - """Returns (total received amount, expected amount) or None.""" + def get_mpp_amounts(self, payment_key: bytes) -> Optional[int]: + """Returns total received amount or None.""" mpp_status = self.received_mpp_htlcs.get(payment_key.hex()) if not mpp_status: return None - total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set]) - return total, mpp_status.expected_msat - - def get_first_timestamp_of_mpp(self, payment_key: bytes) -> int: - mpp_status = self.received_mpp_htlcs.get(payment_key.hex()) - if not mpp_status: - return int(time.time()) - return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set]) + total = sum([mpp_htlc.htlc.amount_msat for mpp_htlc in mpp_status.htlcs]) + return total def maybe_cleanup_mpp( self, - short_channel_id: ShortChannelID, - htlc: UpdateAddHtlc, + chan: Channel, ) -> None: - - htlc_key = (short_channel_id, htlc) + """ + Remove all remaining mpp htlcs of the given channel after closing. + Usually they get removed in htlc_switch after all htlcs of the set are resolved, + however if there is a force close with pending htlcs they need to be removed after the channel + is closed. + """ + # only cleanup when channel is REDEEMED as mpp set is still required for lnsweep + assert chan._state == ChannelState.REDEEMED for payment_key_hex, mpp_status in list(self.received_mpp_htlcs.items()): - if htlc_key not in mpp_status.htlc_set: - continue - assert mpp_status.resolution != RecvMPPResolution.WAITING - self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP {payment_key_hex}') - mpp_status.htlc_set.remove(htlc_key) # side-effecting htlc_set - if len(mpp_status.htlc_set) == 0: + htlcs_to_remove = [htlc for htlc in mpp_status.htlcs if htlc.scid == chan.short_channel_id] + for stale_mpp_htlc in htlcs_to_remove: + assert mpp_status.resolution != RecvMPPResolution.WAITING + self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP {payment_key_hex}') + mpp_status.htlcs.remove(stale_mpp_htlc) # side-effecting htlc_set + if len(mpp_status.htlcs) == 0: self.logger.info(f'maybe_cleanup_mpp: removing mpp {payment_key_hex}') - self.received_mpp_htlcs.pop(payment_key_hex) + del self.received_mpp_htlcs[payment_key_hex] self.maybe_cleanup_forwarding(payment_key_hex) def maybe_cleanup_forwarding(self, payment_key_hex: str) -> None: @@ -2681,6 +2731,7 @@ class LNWallet(LNWorker): for payment_key, htlcs in self.active_forwardings.items(): if htlc_key in htlcs: return payment_key + return None def notify_upstream_peer(self, htlc_key: str) -> None: """Called when an HTLC we offered on chan gets irrevocably fulfilled or failed. @@ -3436,7 +3487,58 @@ class LNWallet(LNWorker): util.trigger_callback('channels_updated', self.wallet) self.lnwatcher.add_channel(cb) - async def maybe_forward_htlc( + async def maybe_forward_htlc_set( + self, + payment_key: str, *, + processed_htlc_set: dict[ReceivedMPPHtlc, Tuple[ProcessedOnionPacket, Optional[ProcessedOnionPacket]]], + ) -> None: + assert self.enable_htlc_forwarding + assert payment_key not in self.active_forwardings, "cannot forward set twice" + self.active_forwardings[payment_key] = [] + self.logger.debug(f"adding active_forwarding: {payment_key=}") + + any_mpp_htlc, (any_outer_onion, any_trampoline_onion) = next(iter(processed_htlc_set.items())) + try: + if any_trampoline_onion is None: + assert not any_outer_onion.are_we_final + assert len(processed_htlc_set) == 1, processed_htlc_set + forward_htlc = any_mpp_htlc.htlc + incoming_chan = self.get_channel_by_short_id(any_mpp_htlc.scid) + next_htlc = await self._maybe_forward_htlc( + incoming_chan=incoming_chan, + htlc=forward_htlc, + processed_onion=any_outer_onion, + ) + htlc_key = serialize_htlc_key(incoming_chan.get_scid_or_local_alias(), forward_htlc.htlc_id) + self.active_forwardings[payment_key].append(next_htlc) + self.downstream_to_upstream_htlc[next_htlc] = htlc_key + else: + assert not any_trampoline_onion.are_we_final and any_outer_onion.are_we_final + # trampoline forwarding + min_inc_cltv_abs = min( + mpp_htlc.htlc.cltv_abs + for mpp_htlc in processed_htlc_set.keys()) # take "min" to assume worst-case + await self._maybe_forward_trampoline( + payment_hash=any_mpp_htlc.htlc.payment_hash, + closest_inc_cltv_abs=min_inc_cltv_abs, + total_msat=any_outer_onion.total_msat, + any_trampoline_onion=any_trampoline_onion, + fw_payment_key=payment_key, + ) + except OnionRoutingFailure as e: + self.logger.debug(f"forwarding failed: {e=}") + if len(self.active_forwardings[payment_key]) == 0: + self.save_forwarding_failure(payment_key, failure_message=e) + # TODO what about other errors? + # Could we "catch-all Exception" and fail back the htlcs with e.g. TEMPORARY_NODE_FAILURE? + # - we don't want to fail the inc-HTLC for a syntax error that happens in the callback + # If we don't call save_forwarding_failure(), the inc-HTLC gets stuck until expiry + # and then the inc-channel will get force-closed. + # => forwarding_callback() could have an API with two exceptions types: + # - type1, such as OnionRoutingFailure, that signals we need to fail back the inc-HTLC + # - type2, such as NoPathFound, that signals we want to retry forwarding + + async def _maybe_forward_htlc( self, *, incoming_chan: Channel, htlc: UpdateAddHtlc, @@ -3550,13 +3652,12 @@ class LNWallet(LNWorker): htlc_key = serialize_htlc_key(next_chan.get_scid_or_local_alias(), next_htlc.htlc_id) return htlc_key - @log_exceptions - async def maybe_forward_trampoline( + async def _maybe_forward_trampoline( self, *, payment_hash: bytes, - inc_cltv_abs: int, - outer_onion: ProcessedOnionPacket, - trampoline_onion: ProcessedOnionPacket, + closest_inc_cltv_abs: int, + total_msat: int, # total_msat of the outer onion + any_trampoline_onion: ProcessedOnionPacket, # any trampoline onion of the incoming htlc set, they should be similar fw_payment_key: str, ) -> None: @@ -3565,7 +3666,7 @@ class LNWallet(LNWorker): if not (forwarding_enabled and forwarding_trampoline_enabled): self.logger.info(f"trampoline forwarding is disabled. failing htlc.") raise OnionRoutingFailure(code=OnionFailureCode.PERMANENT_CHANNEL_FAILURE, data=b'') - payload = trampoline_onion.hop_data.payload + payload = any_trampoline_onion.hop_data.payload payment_data = payload.get('payment_data') try: payment_secret = payment_data['payment_secret'] if payment_data else os.urandom(32) @@ -3583,7 +3684,7 @@ class LNWallet(LNWorker): else: self.logger.info('forward_trampoline: end-to-end') invoice_features = LnFeatures.BASIC_MPP_OPT - next_trampoline_onion = trampoline_onion.next_packet + next_trampoline_onion = any_trampoline_onion.next_packet r_tags = [] except Exception as e: self.logger.exception('') @@ -3597,13 +3698,12 @@ class LNWallet(LNWorker): # these are the fee/cltv paid by the sender # pay_to_node will raise if they are not sufficient - total_msat = outer_onion.hop_data.payload["payment_data"]["total_msat"] budget = PaymentFeeBudget( fee_msat=total_msat - amt_to_forward, - cltv=inc_cltv_abs - out_cltv_abs, + cltv=closest_inc_cltv_abs - out_cltv_abs, ) self.logger.info(f'trampoline forwarding. budget={budget}') - self.logger.info(f'trampoline forwarding. {inc_cltv_abs=}, {out_cltv_abs=}') + self.logger.info(f'trampoline forwarding. {closest_inc_cltv_abs=}, {out_cltv_abs=}') # To convert abs vs rel cltvs, we need to guess blockheight used by original sender as "current blockheight". # Blocks might have been mined since. # - if we skew towards the past, we decrease our own cltv_budget accordingly (which is ok) @@ -3715,7 +3815,6 @@ class LNWallet(LNWorker): final_cltv_abs=final_cltv_abs, total_msat=total_msat, payment_secret=payment_secret) - num_hops = len(hops_data) self.logger.info(f"pay len(route)={len(route)}. for payment_hash={payment_hash.hex()}") for i in range(len(route)): self.logger.info(f" {i}: edge={route[i].short_channel_id} hop_data={hops_data[i]!r}") diff --git a/electrum/submarine_swaps.py b/electrum/submarine_swaps.py index 3705aa572..de6a40a54 100644 --- a/electrum/submarine_swaps.py +++ b/electrum/submarine_swaps.py @@ -37,8 +37,7 @@ from .util import ( run_sync_function_on_asyncio_thread, trigger_callback, NoDynamicFeeEstimates, UserFacingException, ) from . import lnutil -from .lnutil import (hex_to_bytes, REDEEM_AFTER_DOUBLE_SPENT_DELAY, Keypair, - MIN_FINAL_CLTV_DELTA_ACCEPTED) +from .lnutil import hex_to_bytes, REDEEM_AFTER_DOUBLE_SPENT_DELAY, Keypair from .lnaddr import lndecode from .json_db import StoredObject, stored_in from . import constants @@ -257,7 +256,7 @@ class SwapManager(Logger): payment_hash = bytes.fromhex(payment_hash_hex) swap._payment_hash = payment_hash self._add_or_reindex_swap(swap, is_new=False) - if not swap.is_reverse and not swap.is_redeemed: + if not swap.is_reverse and not swap.is_redeemed and not self.lnworker.get_preimage(swap.payment_hash): self.lnworker.register_hold_invoice(payment_hash, self.hold_invoice_callback) self._prepayments = {} # type: Dict[bytes, bytes] # fee_rhash -> rhash @@ -399,8 +398,8 @@ class SwapManager(Logger): def _fail_swap(self, swap: SwapData, reason: str): self.logger.info(f'failing swap {swap.payment_hash.hex()}: {reason}') if not swap.is_reverse and swap.payment_hash in self.lnworker.hold_invoice_callbacks: + # unregister_hold_invoice will fail pending htlcs if there is no preimage available self.lnworker.unregister_hold_invoice(swap.payment_hash) - # Peer.maybe_fulfill_htlc will fail incoming htlcs if there is no payment info self.lnworker.delete_payment_info(swap.payment_hash.hex()) self.lnworker.clear_invoices_cache() self.lnwatcher.remove_callback(swap.lockup_address) @@ -415,7 +414,7 @@ class SwapManager(Logger): self._prepayments.pop(swap.prepay_hash, None) if self.lnworker.get_payment_status(swap.prepay_hash) != PR_PAID: self.lnworker.delete_payment_info(swap.prepay_hash.hex()) - self.lnworker.delete_payment_bundle(swap.payment_hash) + self.lnworker.delete_payment_bundle(payment_hash=swap.payment_hash) @classmethod def extract_preimage(cls, swap: SwapData, claim_tx: Transaction) -> Optional[bytes]: @@ -473,7 +472,7 @@ class SwapManager(Logger): # cleanup self.lnwatcher.remove_callback(swap.lockup_address) if not swap.is_reverse: - self.lnworker.delete_payment_bundle(swap.payment_hash) + self.lnworker.delete_payment_bundle(payment_hash=swap.payment_hash) self.lnworker.unregister_hold_invoice(swap.payment_hash) if not swap.is_reverse: @@ -690,7 +689,7 @@ class SwapManager(Logger): self.lnworker.add_payment_info_for_hold_invoice( payment_hash, lightning_amount_sat=invoice_amount_sat, - min_final_cltv_delta=min_final_cltv_expiry_delta or MIN_FINAL_CLTV_DELTA_ACCEPTED, + min_final_cltv_delta=min_final_cltv_expiry_delta or lnutil.MIN_FINAL_CLTV_DELTA_ACCEPTED, exp_delay=300, ) info = self.lnworker.get_payment_info(payment_hash) @@ -709,7 +708,7 @@ class SwapManager(Logger): if prepay: prepay_hash = self.lnworker.create_payment_info( amount_msat=prepay_amount_sat*1000, - min_final_cltv_delta=min_final_cltv_expiry_delta or MIN_FINAL_CLTV_DELTA_ACCEPTED, + min_final_cltv_delta=min_final_cltv_expiry_delta or lnutil.MIN_FINAL_CLTV_DELTA_ACCEPTED, exp_delay=300, ) info = self.lnworker.get_payment_info(prepay_hash) diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py index 09cf9d948..bf22f03fb 100644 --- a/electrum/wallet_db.py +++ b/electrum/wallet_db.py @@ -22,33 +22,29 @@ # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import os -import ast import datetime import json import copy -import threading from collections import defaultdict from typing import (Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence, TYPE_CHECKING, Union, AbstractSet) -import binascii import time from functools import partial import attr -from . import util, bitcoin -from .util import profiler, WalletFileException, multisig_type, TxMinedInfo, bfh, MyEncoder -from .invoices import Invoice, Request +from . import bitcoin +from .util import profiler, WalletFileException, multisig_type, TxMinedInfo, MyEncoder from .keystore import bip44_derivation from .transaction import Transaction, TxOutpoint, tx_from_any, PartialTransaction, PartialTxOutput, BadHeaderMagic from .logging import Logger -from .lnutil import HTLCOwner, ChannelType +from .lnutil import HTLCOwner, ChannelType, RecvMPPResolution from . import json_db -from .json_db import StoredDict, JsonDB, locked, modifier, StoredObject, stored_in, stored_as +from .json_db import JsonDB, locked, modifier, StoredObject, stored_in, stored_as from .plugin import run_hook, plugin_loaders from .version import ELECTRUM_VERSION +from .i18n import _ if TYPE_CHECKING: from .storage import WalletStorage @@ -73,7 +69,7 @@ class WalletUnfinished(WalletFileException): # seed_version is now used for the version of the wallet file OLD_SEED_VERSION = 4 # electrum versions < 2.0 NEW_SEED_VERSION = 11 # electrum versions >= 2.0 -FINAL_SEED_VERSION = 62 # electrum >= 2.7 will set this to prevent +FINAL_SEED_VERSION = 63 # electrum >= 2.7 will set this to prevent # old versions from overwriting new format @@ -238,6 +234,7 @@ class WalletDBUpgrader(Logger): self._convert_version_60() self._convert_version_61() self._convert_version_62() + self._convert_version_63() self.put('seed_version', FINAL_SEED_VERSION) # just to be sure def _convert_wallet_type(self): @@ -1182,6 +1179,96 @@ class WalletDBUpgrader(Logger): swap['claim_to_output'] = None self.data['seed_version'] = 62 + def _convert_version_63(self): + if not self._is_upgrade_method_needed(62, 62): + return + # Old ReceivedMPPStatus: + # class ReceivedMPPStatus(NamedTuple): + # resolution: RecvMPPResolution + # expected_msat: int + # htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]] + # + # New ReceivedMPPStatus: + # class ReceivedMPPStatus(NamedTuple): + # resolution: RecvMPPResolution + # htlcs: set[ReceivedMPPHtlc] + # + # class ReceivedMPPHtlc(NamedTuple): + # scid: ShortChannelID + # htlc: UpdateAddHtlc + # unprocessed_onion: str + + # previously chan.unfulfilled_htlcs went through 4 stages: + # - 1. not forwarded yet: (onion_packet_hex, None) + # - 2. forwarded: (onion_packet_hex, forwarding_key) + # - 3. processed: (None, forwarding_key), not irrevocably removed yet + # - 4. done: (None, forwarding_key), irrevocably removed + channels = self.data.get('channels', {}) + def _move_unprocessed_onion(short_channel_id: str, htlc_id: Optional[int]) -> Optional[Tuple[str, Optional[str]]]: + if htlc_id is None: + return None + for chan_ in channels.values(): + if chan_['short_channel_id'] != short_channel_id: + continue + unfulfilled_htlcs_ = chan_.get('unfulfilled_htlcs', {}) + htlc_data = unfulfilled_htlcs_.get(str(htlc_id)) + if htlc_data is None: + return None + stored_onion_packet, htlc_forwarding_key = htlc_data + if stored_onion_packet is not None: + htlc_data[0] = None # overwrite the onion so it is not processed again in htlc_switch + return stored_onion_packet, htlc_forwarding_key + return None + + mpp_sets = self.data.get('received_mpp_htlcs', {}) + for payment_key, recv_mpp_status in list(mpp_sets.items()): + assert isinstance(recv_mpp_status, list), f"{recv_mpp_status=}" + del recv_mpp_status[1] # remove expected_msat + + new_type_htlcs = [] + forwarding_key = None + for scid, update_add_htlc in recv_mpp_status[1]: # htlc_set + htlc_info_from_chan = _move_unprocessed_onion(scid, update_add_htlc[3]) + if htlc_info_from_chan is None: + # if there is no onion packet for the htlc it is dropped as it was already + # processed in the old htlc_switch + continue + onion_packet_hex = htlc_info_from_chan[0] + forwarding_key = htlc_info_from_chan[1] if htlc_info_from_chan[1] else forwarding_key + new_type_htlcs.append([ + scid, + update_add_htlc, + onion_packet_hex, + ]) + + if len(new_type_htlcs) == 0: + self.logger.debug(f"_convert_version_62: dropping mpp set {payment_key=}.") + del mpp_sets[payment_key] + else: + recv_mpp_status[1] = new_type_htlcs + self.logger.debug(f"_convert_version_62: migrated mpp set {payment_key=}") + if forwarding_key is not None: + # if the forwarding key is set for the old mpp set it was either a forwarding + # or a swap hold invoice. Assuming users of 4.6.2 don't use forwarding this update + # most likely happens during a swap waiting for the preimage. Setting the mpp set + # to SETTLING prevents us from accidentally failing the htlc set after the update, + # however it carries the risk of the channel getting force closed if the swap fails + # as the htlcs won't get failed due to the new SETTLING state + # unless a forwarding error is set. + recv_mpp_status[0] = 4 # RecvMPPResolution.SETTLING + + # replace Tuple[onion, forwarding_key] with just the onion in chan['unfulfilled_htlcs'] + for chan in channels.values(): + unfulfilled_htlcs = chan.get('unfulfilled_htlcs', {}) + for htlc_id, (unprocessed_onion, forwarding_key) in list(unfulfilled_htlcs.items()): + if unprocessed_onion is None: + # delete all unfulfilled_htlcs with empty onion as they are already processed + del unfulfilled_htlcs[htlc_id] + else: + unfulfilled_htlcs[htlc_id] = unprocessed_onion + + self.data['seed_version'] = 63 + def _convert_imported(self): if not self._is_upgrade_method_needed(0, 13): return diff --git a/tests/test_commands.py b/tests/test_commands.py index 28a80a607..455f7b87e 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -549,13 +549,13 @@ class TestCommandsTestnet(ElectrumTestCase): ) mock_htlc1 = mock.Mock() - mock_htlc1.cltv_abs = 800_000 - mock_htlc1.amount_msat = 4_500_000 + mock_htlc1.htlc.cltv_abs = 800_000 + mock_htlc1.htlc.amount_msat = 4_500_000 mock_htlc2 = mock.Mock() - mock_htlc2.cltv_abs = 800_144 - mock_htlc2.amount_msat = 5_500_000 + mock_htlc2.htlc.cltv_abs = 800_144 + mock_htlc2.htlc.amount_msat = 5_500_000 mock_htlc_status = mock.Mock() - mock_htlc_status.htlc_set = [(None, mock_htlc1), (None, mock_htlc2)] + mock_htlc_status.htlcs = [mock_htlc1, mock_htlc2] mock_htlc_status.resolution = RecvMPPResolution.COMPLETE payment_key = wallet.lnworker._get_payment_key(bytes.fromhex(payment_hash)).hex() diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 360ede67e..fbd148f70 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -301,7 +301,6 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): set_request_status = LNWallet.set_request_status set_payment_status = LNWallet.set_payment_status get_payment_status = LNWallet.get_payment_status - check_mpp_status = LNWallet.check_mpp_status htlc_fulfilled = LNWallet.htlc_fulfilled htlc_failed = LNWallet.htlc_failed save_preimage = LNWallet.save_preimage @@ -334,11 +333,9 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): unregister_hold_invoice = LNWallet.unregister_hold_invoice add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice - update_mpp_with_received_htlc = LNWallet.update_mpp_with_received_htlc + update_or_create_mpp_with_received_htlc = LNWallet.update_or_create_mpp_with_received_htlc set_mpp_resolution = LNWallet.set_mpp_resolution - is_mpp_amount_reached = LNWallet.is_mpp_amount_reached get_mpp_amounts = LNWallet.get_mpp_amounts - get_first_timestamp_of_mpp = LNWallet.get_first_timestamp_of_mpp bundle_payments = LNWallet.bundle_payments get_payment_bundle = LNWallet.get_payment_bundle _get_payment_key = LNWallet._get_payment_key @@ -347,11 +344,14 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): maybe_cleanup_forwarding = LNWallet.maybe_cleanup_forwarding current_target_feerate_per_kw = LNWallet.current_target_feerate_per_kw current_low_feerate_per_kw_srk_channel = LNWallet.current_low_feerate_per_kw_srk_channel - maybe_cleanup_mpp = LNWallet.maybe_cleanup_mpp create_onion_for_route = LNWallet.create_onion_for_route - maybe_forward_htlc = LNWallet.maybe_forward_htlc - maybe_forward_trampoline = LNWallet.maybe_forward_trampoline + maybe_forward_htlc_set = LNWallet.maybe_forward_htlc_set + _maybe_forward_htlc = LNWallet._maybe_forward_htlc + _maybe_forward_trampoline = LNWallet._maybe_forward_trampoline _maybe_refuse_to_forward_htlc_that_corresponds_to_payreq_we_created = LNWallet._maybe_refuse_to_forward_htlc_that_corresponds_to_payreq_we_created + set_htlc_set_error = LNWallet.set_htlc_set_error + is_payment_bundle_complete = LNWallet.is_payment_bundle_complete + delete_payment_bundle = LNWallet.delete_payment_bundle _process_htlc_log = LNWallet._process_htlc_log