From 8a88ebe6bcb4e3d23a9b1a90a1611000d5a4720b Mon Sep 17 00:00:00 2001 From: f321x Date: Mon, 8 Dec 2025 11:00:47 +0100 Subject: [PATCH 1/6] lnworker: add type assert to get_channel_by_short_id Prevents accidentally passing None if channel.short_id is not set yet --- electrum/lnworker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index c49aa3dfc..f28c99a16 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -1574,6 +1574,7 @@ class LNWallet(LNWorker): return chan, funding_tx def get_channel_by_short_id(self, short_channel_id: bytes) -> Optional[Channel]: + assert short_channel_id and isinstance(short_channel_id, bytes), repr(short_channel_id) # First check against *real* SCIDs. # This e.g. protects against maliciously chosen SCID aliases, and accidental collisions. for chan in self.channels.values(): From 5be598b808457ad917242053a081551a1261521d Mon Sep 17 00:00:00 2001 From: f321x Date: Mon, 8 Dec 2025 12:00:52 +0100 Subject: [PATCH 2/6] lnworker: use channel_id instead of scid in ReceivedMPPHtlc Store the channel id instead of the scid in ReceivedMPPHtlc. The scid can be None, in theory even for multiple channels at the same time. Using the channel_id which is always available and unique seems less error prone at the cost of temporarily higher storage requirements in the db for the duration of the pending htlcs. Alternatively we could use the local scid alias however using the channel_id seems less complex and leaves less room for ambiguity. --- electrum/lnpeer.py | 15 ++++++++------- electrum/lnutil.py | 10 +++++----- electrum/lnworker.py | 10 +++++----- electrum/wallet_db.py | 29 ++++++++++++++++++++++++++++- 4 files changed, 46 insertions(+), 18 deletions(-) diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index eaaddacfa..ec1362944 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -2150,7 +2150,7 @@ class Peer(Logger, EventListener): Does additional checks on the incoming htlc and return the payment key if the tests pass, otherwise raises OnionRoutingError which will get the htlc failed. """ - _log_fail_reason = self._log_htlc_fail_reason_cb(chan.short_channel_id, htlc, processed_onion.hop_data.payload) + _log_fail_reason = self._log_htlc_fail_reason_cb(chan.channel_id, htlc, processed_onion.hop_data.payload) # Check that our blockchain tip is sufficiently recent so that we have an approx idea of the height. # We should not release the preimage for an HTLC that its sender could already time out as @@ -2269,7 +2269,7 @@ class Peer(Logger, EventListener): self.lnworker.dont_expire_htlcs.pop(payment_hash.hex(), None) # htlcs wont get expired anymore for mpp_htlc in list(htlc_set.htlcs): htlc_id = mpp_htlc.htlc.htlc_id - chan = self.lnworker.get_channel_by_short_id(mpp_htlc.scid) + chan = self.lnworker.channels[mpp_htlc.channel_id] if chan.channel_id not in self.channels: # this htlc belongs to another peer and has to be settled in their htlc_switch continue @@ -2311,7 +2311,7 @@ class Peer(Logger, EventListener): self.lnworker.dont_expire_htlcs.pop(payment_hash.hex(), None) self.lnworker.dont_settle_htlcs.pop(payment_hash.hex(), None) # already failed for mpp_htlc in list(htlc_set.htlcs): - chan = self.lnworker.get_channel_by_short_id(mpp_htlc.scid) + chan = self.lnworker.channels[mpp_htlc.channel_id] htlc_id = mpp_htlc.htlc.htlc_id if chan.channel_id not in self.channels: # this htlc belongs to another peer and has to be settled in their htlc_switch @@ -2854,7 +2854,7 @@ class Peer(Logger, EventListener): ) self.lnworker.update_or_create_mpp_with_received_htlc( payment_key=payment_key, - scid=chan.short_channel_id, + channel_id=chan.channel_id, htlc=htlc, unprocessed_onion_packet=onion_packet_hex, # outer onion if trampoline ) @@ -2938,11 +2938,12 @@ class Peer(Logger, EventListener): def _log_htlc_fail_reason_cb( self, - scid: ShortChannelID, + channel_id: bytes, htlc: UpdateAddHtlc, onion_payload: dict ) -> Callable[[str], None]: def _log_fail_reason(reason: str) -> None: + scid = self.lnworker.channels[channel_id].short_channel_id self.logger.info(f"will FAIL HTLC: {str(scid)=}. {reason=}. {str(htlc)=}. {onion_payload=}") return _log_fail_reason @@ -2960,7 +2961,7 @@ class Peer(Logger, EventListener): onion_payload = {} self._log_htlc_fail_reason_cb( - mpp_htlc.scid, + mpp_htlc.channel_id, mpp_htlc.htlc, onion_payload, )(f"mpp set {id(mpp_set)} failed: {reason}") @@ -3074,7 +3075,7 @@ class Peer(Logger, EventListener): if mpp_set.resolution == RecvMPPResolution.WAITING: # calculate the sum of just in time channel opening fees - htlc_channels = [self.lnworker.get_channel_by_short_id(scid) for scid in set(h.scid for h in mpp_set.htlcs)] + htlc_channels = [self.lnworker.channels[channel_id] for channel_id in set(h.channel_id for h in mpp_set.htlcs)] jit_opening_fees_msat = sum((c.jit_opening_fee or 0) for c in htlc_channels) # check if set is first stage multi-trampoline payment to us diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 70256a74d..f670b4bb2 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -1965,18 +1965,18 @@ del r class ReceivedMPPHtlc(NamedTuple): - scid: ShortChannelID + channel_id: bytes htlc: UpdateAddHtlc unprocessed_onion: str def __repr__(self): - return f"{self.scid}, {self.htlc=}, {self.unprocessed_onion[:15]=}..." + return f"chan_id={self.channel_id.hex()}, {self.htlc=}, {self.unprocessed_onion[:15]=}..." @staticmethod - def from_tuple(scid, htlc, unprocessed_onion) -> 'ReceivedMPPHtlc': - assert is_hex_str(unprocessed_onion) and is_hex_str(scid) + def from_tuple(channel_id, htlc, unprocessed_onion) -> 'ReceivedMPPHtlc': + assert is_hex_str(unprocessed_onion) and is_hex_str(channel_id) return ReceivedMPPHtlc( - scid=ShortChannelID(bytes.fromhex(scid)), + channel_id=bytes.fromhex(channel_id), htlc=UpdateAddHtlc.from_tuple(*htlc), unprocessed_onion=unprocessed_onion, ) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index f28c99a16..1db9d0de0 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -2583,7 +2583,7 @@ class LNWallet(LNWorker): self, *, payment_key: str, - scid: ShortChannelID, + channel_id: bytes, htlc: UpdateAddHtlc, unprocessed_onion_packet: str, ): @@ -2610,7 +2610,7 @@ class LNWallet(LNWorker): if mpp_status.resolution > RecvMPPResolution.WAITING: # we are getting a htlc for a set that is not in WAITING state, it cannot be safely added - self.logger.info(f"htlc set cannot accept htlc, failing htlc: {scid=} {htlc.htlc_id=}") + self.logger.info(f"htlc set cannot accept htlc, failing htlc: {channel_id=} {htlc.htlc_id=}") if mpp_status == RecvMPPResolution.EXPIRED: raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'') raise OnionRoutingFailure( @@ -2619,7 +2619,7 @@ class LNWallet(LNWorker): ) new_htlc = ReceivedMPPHtlc( - scid=scid, + channel_id=channel_id, htlc=htlc, unprocessed_onion=unprocessed_onion_packet, ) @@ -2707,7 +2707,7 @@ class LNWallet(LNWorker): # only cleanup when channel is REDEEMED as mpp set is still required for lnsweep assert chan._state == ChannelState.REDEEMED for payment_key_hex, mpp_status in list(self.received_mpp_htlcs.items()): - htlcs_to_remove = [htlc for htlc in mpp_status.htlcs if htlc.scid == chan.short_channel_id] + htlcs_to_remove = [htlc for htlc in mpp_status.htlcs if htlc.channel_id == chan.channel_id] for stale_mpp_htlc in htlcs_to_remove: assert mpp_status.resolution != RecvMPPResolution.WAITING self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP {payment_key_hex}') @@ -3547,7 +3547,7 @@ class LNWallet(LNWorker): assert not any_outer_onion.are_we_final assert len(processed_htlc_set) == 1, processed_htlc_set forward_htlc = any_mpp_htlc.htlc - incoming_chan = self.get_channel_by_short_id(any_mpp_htlc.scid) + incoming_chan = self.channels[any_mpp_htlc.channel_id] next_htlc = await self._maybe_forward_htlc( incoming_chan=incoming_chan, htlc=forward_htlc, diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py index 325eb36f3..068499c2d 100644 --- a/electrum/wallet_db.py +++ b/electrum/wallet_db.py @@ -69,7 +69,7 @@ class WalletUnfinished(WalletFileException): # seed_version is now used for the version of the wallet file OLD_SEED_VERSION = 4 # electrum versions < 2.0 NEW_SEED_VERSION = 11 # electrum versions >= 2.0 -FINAL_SEED_VERSION = 64 # electrum >= 2.7 will set this to prevent +FINAL_SEED_VERSION = 65 # electrum >= 2.7 will set this to prevent # old versions from overwriting new format @@ -236,6 +236,7 @@ class WalletDBUpgrader(Logger): self._convert_version_62() self._convert_version_63() self._convert_version_64() + self._convert_version_65() self.put('seed_version', FINAL_SEED_VERSION) # just to be sure def _convert_wallet_type(self): @@ -1288,6 +1289,32 @@ class WalletDBUpgrader(Logger): self.data['lightning_payments'] = new_payment_infos self.data['seed_version'] = 64 + def _convert_version_65(self): + """Store channel_id instead of short_channel_id in ReceivedMPPHtlc""" + if not self._is_upgrade_method_needed(64, 64): + return + + channels = self.data.get('channels', {}) + def scid_to_channel_id(scid): + for channel_id, channel_data in channels.items(): + if scid == channel_data.get('short_channel_id'): + return channel_id + raise KeyError(f"missing {scid=} in channels") + + mpp_sets = self.data.get('received_mpp_htlcs', {}) + new_mpp_sets = {} + for payment_key, mpp_set in mpp_sets.items(): + resolution, htlc_list, parent_set_key = mpp_set + new_htlc_list = [] + for htlc_data_tuple in htlc_list: + scid, update_add_htlc, onion = htlc_data_tuple + channel_id = scid_to_channel_id(scid) + new_htlc_list.append((channel_id, update_add_htlc, onion)) + new_mpp_sets[payment_key] = (resolution, new_htlc_list, parent_set_key) + + self.data['received_mpp_htlcs'] = new_mpp_sets + self.data['seed_version'] = 65 + def _convert_imported(self): if not self._is_upgrade_method_needed(0, 13): return From 125a921cc4b52b8e6e4d3e63aa6562cebe6d3baf Mon Sep 17 00:00:00 2001 From: f321x Date: Tue, 9 Dec 2025 14:31:12 +0100 Subject: [PATCH 3/6] lnworker: add invoice features to PaymentInfo class Adds the invoice features to the `PaymentInfo` class so we can check if the sender respects our requested features (e.g. if they tried to send mpp if we requested no mpp). --- electrum/lnworker.py | 51 ++++++++++++++++++++++++++----------------- electrum/wallet_db.py | 19 +++++++++++++++- tests/test_lnpeer.py | 4 ++++ 3 files changed, 53 insertions(+), 21 deletions(-) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 1db9d0de0..938b05ba8 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -134,6 +134,7 @@ class PaymentInfo: min_final_cltv_delta: int expiry_delay: int creation_ts: int = dataclasses.field(default_factory=lambda: int(time.time())) + invoice_features: LnFeatures @property def expiration_ts(self): @@ -147,6 +148,7 @@ class PaymentInfo: assert isinstance(self.min_final_cltv_delta, int) assert isinstance(self.expiry_delay, int) and self.expiry_delay > 0 assert isinstance(self.creation_ts, int) + assert isinstance(self.invoice_features, LnFeatures) def __post_init__(self): self.validate() @@ -903,8 +905,8 @@ class LNWallet(LNWorker): LNWorker.__init__(self, self.node_keypair, features, config=self.config) self.lnwatcher = LNWatcher(self) self.lnrater: LNRater = None - # lightning_payments: "RHASH:direction" -> amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts - self.payment_info = self.db.get_dict('lightning_payments') # type: dict[str, Tuple[Optional[int], int, int, int, int]] + # "RHASH:direction" -> amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts, invoice_features + self.payment_info = self.db.get_dict('lightning_payments') # type: dict[str, Tuple[Optional[int], int, int, int, int, int]] self._preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage self._bolt11_cache = {} # note: this sweep_address is only used as fallback; as it might result in address-reuse @@ -1628,6 +1630,7 @@ class LNWallet(LNWorker): status=PR_UNPAID, min_final_cltv_delta=min_final_cltv_delta, expiry_delay=LN_EXPIRY_NEVER, + invoice_features=invoice_features, ) self.save_payment_info(info) self.wallet.set_label(key, lnaddr.get_description()) @@ -2332,6 +2335,16 @@ class LNWallet(LNWorker): route[-1].node_features |= invoice_features return route + def _get_invoice_features(self, amount_msat: Optional[int]) -> LnFeatures: + invoice_features = self.features.for_invoice() + if not self.uses_trampoline(): + invoice_features &= ~ LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM + needs_jit: bool = self.receive_requires_jit_channel(amount_msat) + if needs_jit: + # jit only works with single htlcs, mpp will cause LSP to open channels for each htlc + invoice_features &= ~ LnFeatures.BASIC_MPP_OPT & ~ LnFeatures.BASIC_MPP_REQ + return invoice_features + def clear_invoices_cache(self): self._bolt11_cache.clear() @@ -2351,15 +2364,8 @@ class LNWallet(LNWorker): assert amount_msat is None or amount_msat > 0 timestamp = int(time.time()) - needs_jit: bool = self.receive_requires_jit_channel(amount_msat) - routing_hints = self.calc_routing_hints_for_invoice(amount_msat, channels=channels, needs_jit=needs_jit) - self.logger.info(f"creating bolt11 invoice with routing_hints: {routing_hints}, jit: {needs_jit}, sat: {(amount_msat or 0) // 1000}") - invoice_features = self.features.for_invoice() - if not self.uses_trampoline(): - invoice_features &= ~ LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM - if needs_jit: - # jit only works with single htlcs, mpp will cause LSP to open channels for each htlc - invoice_features &= ~ LnFeatures.BASIC_MPP_OPT & ~ LnFeatures.BASIC_MPP_REQ + routing_hints = self.calc_routing_hints_for_invoice(amount_msat, channels=channels) + self.logger.info(f"creating bolt11 invoice with routing_hints: {routing_hints}, sat: {(amount_msat or 0) // 1000}") payment_secret = self.get_payment_secret(payment_info.payment_hash) amount_btc = amount_msat/Decimal(COIN*1000) if amount_msat else None min_final_cltv_delta = payment_info.min_final_cltv_delta + MIN_FINAL_CLTV_DELTA_BUFFER_INVOICE @@ -2370,7 +2376,7 @@ class LNWallet(LNWorker): ('d', message), ('c', min_final_cltv_delta), ('x', payment_info.expiry_delay), - ('9', invoice_features), + ('9', payment_info.invoice_features), ('f', fallback_address), ] + routing_hints, date=timestamp, @@ -2401,13 +2407,15 @@ class LNWallet(LNWorker): payment_preimage = os.urandom(32) payment_hash = sha256(payment_preimage) min_final_cltv_delta = min_final_cltv_delta or MIN_FINAL_CLTV_DELTA_ACCEPTED + invoice_features = self._get_invoice_features(amount_msat) info = PaymentInfo( payment_hash=payment_hash, amount_msat=amount_msat, direction=RECEIVED, status=PR_UNPAID, min_final_cltv_delta=min_final_cltv_delta, - expiry_delay=exp_delay + expiry_delay=exp_delay, + invoice_features=invoice_features, ) self.save_preimage(payment_hash, payment_preimage, write_to_disk=False) self.save_payment_info(info, write_to_disk=False) @@ -2514,7 +2522,7 @@ class LNWallet(LNWorker): with self.lock: if key in self.payment_info: stored_tuple = self.payment_info[key] - amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts = stored_tuple + amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts, invoice_features = stored_tuple return PaymentInfo( payment_hash=payment_hash, amount_msat=amount_msat, @@ -2523,6 +2531,7 @@ class LNWallet(LNWorker): min_final_cltv_delta=min_final_cltv_delta, expiry_delay=expiry_delay, creation_ts=creation_ts, + invoice_features=LnFeatures(invoice_features), ) return None @@ -2533,14 +2542,15 @@ class LNWallet(LNWorker): min_final_cltv_delta: int, exp_delay: int, ): - amount = lightning_amount_sat * 1000 if lightning_amount_sat else None + amount_msat = lightning_amount_sat * 1000 if lightning_amount_sat else None info = PaymentInfo( payment_hash=payment_hash, - amount_msat=amount, + amount_msat=amount_msat, direction=RECEIVED, status=PR_UNPAID, min_final_cltv_delta=min_final_cltv_delta, expiry_delay=exp_delay, + invoice_features=self._get_invoice_features(amount_msat), ) self.save_payment_info(info, write_to_disk=False) @@ -2574,7 +2584,7 @@ class LNWallet(LNWorker): if info != dataclasses.replace(old_info, status=info.status): # differs more than in status. let's fail raise Exception(f"payment_hash already in use: {info=} != {old_info=}") - v = info.amount_msat, info.status, info.min_final_cltv_delta, info.expiry_delay, info.creation_ts + v = info.amount_msat, info.status, info.min_final_cltv_delta, info.expiry_delay, info.creation_ts, int(info.invoice_features) self.payment_info[info.db_key] = v if write_to_disk: self.wallet.save_db() @@ -2907,10 +2917,11 @@ class LNWallet(LNWorker): else: self.logger.info(f'htlc_failed: waiting for other htlcs to fail (phash={payment_hash.hex()})') - def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=None, needs_jit=False): + def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=None): """calculate routing hints (BOLT-11 'r' field)""" routing_hints = [] - if needs_jit: + if self.receive_requires_jit_channel(amount_msat): + self.logger.debug(f"will request just-in-time channel") node_id, rest = extract_nodeid(self.config.ZEROCONF_TRUSTED_NODE) alias_or_scid = self.get_static_jit_scid_alias() routing_hints.append(('r', [(node_id, alias_or_scid, 0, 0, 144)])) @@ -3093,7 +3104,7 @@ class LNWallet(LNWorker): # check if zeroconf is accepted and client has trusted zeroconf node configured return False try: - node_id = extract_nodeid(self.wallet.config.ZEROCONF_TRUSTED_NODE)[0] + node_id = extract_nodeid(self.config.ZEROCONF_TRUSTED_NODE)[0] except ConnStringFormatError: # invalid connection string return False diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py index 068499c2d..688ab4c17 100644 --- a/electrum/wallet_db.py +++ b/electrum/wallet_db.py @@ -69,7 +69,7 @@ class WalletUnfinished(WalletFileException): # seed_version is now used for the version of the wallet file OLD_SEED_VERSION = 4 # electrum versions < 2.0 NEW_SEED_VERSION = 11 # electrum versions >= 2.0 -FINAL_SEED_VERSION = 65 # electrum >= 2.7 will set this to prevent +FINAL_SEED_VERSION = 66 # electrum >= 2.7 will set this to prevent # old versions from overwriting new format @@ -237,6 +237,7 @@ class WalletDBUpgrader(Logger): self._convert_version_63() self._convert_version_64() self._convert_version_65() + self._convert_version_66() self.put('seed_version', FINAL_SEED_VERSION) # just to be sure def _convert_wallet_type(self): @@ -1315,6 +1316,22 @@ class WalletDBUpgrader(Logger): self.data['received_mpp_htlcs'] = new_mpp_sets self.data['seed_version'] = 65 + def _convert_version_66(self): + """Add invoice features to PaymentInfo""" + if not self._is_upgrade_method_needed(65, 65): + return + + new_payment_infos = {} + old_payment_infos = self.data.get('lightning_payments', {}) + for key, old_v in old_payment_infos.items(): + amount_msat, status, min_final_cltv_expiry, expiry, creation_ts = old_v + invoice_features = 147712 # + new_v = (amount_msat, status, min_final_cltv_expiry, expiry, creation_ts, invoice_features) + new_payment_infos[key] = new_v + + self.data['lightning_payments'] = new_payment_infos + self.data['seed_version'] = 66 + def _convert_imported(self): if not self._is_upgrade_method_needed(0, 13): return diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 57def01cb..68e4d129e 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -359,6 +359,9 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): is_payment_bundle_complete = LNWallet.is_payment_bundle_complete delete_payment_bundle = LNWallet.delete_payment_bundle _process_htlc_log = LNWallet._process_htlc_log + _get_invoice_features = LNWallet._get_invoice_features + receive_requires_jit_channel = LNWallet.receive_requires_jit_channel + can_get_zeroconf_channel = LNWallet.can_get_zeroconf_channel class MockTransport: @@ -594,6 +597,7 @@ class TestPeer(ElectrumTestCase): status=PR_UNPAID, min_final_cltv_delta=min_final_cltv_delta, expiry_delay=expiry or LN_EXPIRY_NEVER, + invoice_features=invoice_features, ) w2.save_payment_info(info) lnaddr1 = LnAddr( From 183d426e9324ed07825d8936931fd781347ee278 Mon Sep 17 00:00:00 2001 From: f321x Date: Mon, 8 Dec 2025 14:18:05 +0100 Subject: [PATCH 4/6] lnpeer: fail htlcs if we get unwanted mpp Fail incoming htlcs if we receive a payment consisting of multiple parts if we signaled to not want mpp in the invoice. --- electrum/lnpeer.py | 22 +++++++++++++++++++--- tests/test_lnpeer.py | 1 + 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index ec1362944..59480b6a1 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -2240,6 +2240,10 @@ class Peer(Logger, EventListener): elif htlc.timestamp > info.expiration_ts: # the set will get failed too if now > exp_ts _log_fail_reason(f"not accepting htlc for expired invoice") raise exc_incorrect_or_unknown_pd + elif not info.invoice_features.supports(LnFeatures.BASIC_MPP_OPT) and total_msat > htlc.amount_msat: + # in _check_unfulfilled_htlc_set we check the count to prevent mpp through overpayment + _log_fail_reason(f"got mpp but we requested no mpp in the invoice: {total_msat=} > {htlc.amount_msat=}") + raise exc_incorrect_or_unknown_pd expected_payment_secret = self.lnworker.get_payment_secret(payment_hash) if not util.constant_time_compare(payment_secret_from_onion, expected_payment_secret): @@ -3074,7 +3078,8 @@ class Peer(Logger, EventListener): return OnionFailureCode.MPP_TIMEOUT, None, None if mpp_set.resolution == RecvMPPResolution.WAITING: - # calculate the sum of just in time channel opening fees + # calculate the sum of just in time channel opening fees, note jit only supports + # single part payments for now, this is enforced by checking against the invoice features htlc_channels = [self.lnworker.channels[channel_id] for channel_id in set(h.channel_id for h in mpp_set.htlcs)] jit_opening_fees_msat = sum((c.jit_opening_fee or 0) for c in htlc_channels) @@ -3095,15 +3100,21 @@ class Peer(Logger, EventListener): trampoline_payment_key = (payment_hash + trampoline_payment_secret).hex() if trampoline_payment_key and trampoline_payment_key != payment_key: + if jit_opening_fees_msat: + # for jit openings we only accept a single htlc + expected_amount_first_stage = any_trampoline_onion.total_msat - jit_opening_fees_msat + else: + expected_amount_first_stage = any_trampoline_onion.amt_to_forward + # first stage of trampoline payment, the first stage must never get set COMPLETE - if amount_msat >= (any_trampoline_onion.amt_to_forward - jit_opening_fees_msat): + if amount_msat >= expected_amount_first_stage: # setting the parent key will mark the htlcs to be moved to the parent set self.logger.debug(f"trampoline part complete. {len(mpp_set.htlcs)=}, " f"{amount_msat=}. setting parent key: {trampoline_payment_key}") self.lnworker.received_mpp_htlcs[payment_key] = mpp_set._replace( parent_set_key=trampoline_payment_key, ) - elif amount_msat >= (total_msat - jit_opening_fees_msat): + elif amount_msat >= (total_msat - jit_opening_fees_msat): # regular mpp or 2nd stage trampoline # set mpp_set as completed as we have received the full total_msat mpp_set = self.lnworker.set_mpp_resolution( payment_key=payment_key, @@ -3127,6 +3138,11 @@ class Peer(Logger, EventListener): if payment_info is None: _log_fail_reason(f"payment info has been deleted") return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None + elif not payment_info.invoice_features.supports(LnFeatures.BASIC_MPP_OPT) and len(mpp_set.htlcs) > 1: + # in _check_unfulfilled_htlc we already check amount == total_amount, however someone could + # send us multiple htlcs that all pay the full amount, so we also check the htlc count + _log_fail_reason(f"got mpp but we requested no mpp in the invoice: {len(mpp_set.htlcs)=}") + return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None # check invoice expiry, fail set if the invoice has expired before it was completed if mpp_set.resolution == RecvMPPResolution.WAITING: diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 68e4d129e..07662f1d5 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -1452,6 +1452,7 @@ class TestPeerDirect(TestPeer): async def run_test(test_trampoline: bool): alice_channel, bob_channel = create_test_channels() alice_peer, bob_peer, alice_wallet, bob_wallet, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + bob_wallet.features |= LnFeatures.BASIC_MPP_OPT lnaddr1, pay_req1 = self.prepare_invoice(bob_wallet, amount_msat=10_000) if test_trampoline: From 7c01d9db75fc74e768d8c20e04c153c377a5c097 Mon Sep 17 00:00:00 2001 From: f321x Date: Tue, 9 Dec 2025 17:49:25 +0100 Subject: [PATCH 5/6] tests: lnpeer: add test_reject_mpp_for_non_mpp_invoice --- tests/test_lnpeer.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 07662f1d5..941a4a12f 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -1071,6 +1071,46 @@ class TestPeerDirect(TestPeer): for _test_trampoline in [False, True]: await run_test(_test_trampoline) + async def test_reject_mpp_for_non_mpp_invoice(self): + """Test that we reject a payment if it is mpp and we didn't signal support for mpp in the invoice""" + async def run_test(test_trampoline): + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + w1.config.TEST_FORCE_MPP = True # force alice to send mpp + + if test_trampoline: + await self._activate_trampoline(w1) + await self._activate_trampoline(w2) + # declare bob as trampoline node + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { + 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey), + } + + lnaddr, pay_req = self.prepare_invoice(w2) + self.assertFalse(lnaddr.get_features().supports(LnFeatures.BASIC_MPP_OPT)) + self.assertFalse(lnaddr.get_features().supports(LnFeatures.BASIC_MPP_REQ)) + + async def try_pay_invoice_with_mpp(pay_req: Invoice, w1=w1): + result, log = await w1.pay_invoice(pay_req) + if not result: + raise PaymentFailure() + raise PaymentDone() + + async def f(): + async with OldTaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p1.htlc_switch()) + await group.spawn(p2._message_loop()) + await group.spawn(p2.htlc_switch()) + await asyncio.sleep(0.01) + await group.spawn(try_pay_invoice_with_mpp(pay_req)) + + with self.assertRaises(PaymentFailure): + await f() + + for _test_trampoline in [False, True]: + await run_test(_test_trampoline) + async def test_reject_multiple_payments_of_same_invoice(self): """Tests that new htlcs paying an invoice that has already been paid will get rejected.""" async def run_test(test_trampoline): From 745318d1ec3de87ded57cd8fdb16ed6213ffdda0 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Wed, 10 Dec 2025 15:53:17 +0000 Subject: [PATCH 6/6] wallet_db: convert_version_66: trivial simplification --- electrum/wallet_db.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py index 688ab4c17..dfb4bbb23 100644 --- a/electrum/wallet_db.py +++ b/electrum/wallet_db.py @@ -1325,7 +1325,7 @@ class WalletDBUpgrader(Logger): old_payment_infos = self.data.get('lightning_payments', {}) for key, old_v in old_payment_infos.items(): amount_msat, status, min_final_cltv_expiry, expiry, creation_ts = old_v - invoice_features = 147712 # + invoice_features = 0x24100 # new_v = (amount_msat, status, min_final_cltv_expiry, expiry, creation_ts, invoice_features) new_payment_infos[key] = new_v