diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index ec1362944..59480b6a1 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -2240,6 +2240,10 @@ class Peer(Logger, EventListener): elif htlc.timestamp > info.expiration_ts: # the set will get failed too if now > exp_ts _log_fail_reason(f"not accepting htlc for expired invoice") raise exc_incorrect_or_unknown_pd + elif not info.invoice_features.supports(LnFeatures.BASIC_MPP_OPT) and total_msat > htlc.amount_msat: + # in _check_unfulfilled_htlc_set we check the count to prevent mpp through overpayment + _log_fail_reason(f"got mpp but we requested no mpp in the invoice: {total_msat=} > {htlc.amount_msat=}") + raise exc_incorrect_or_unknown_pd expected_payment_secret = self.lnworker.get_payment_secret(payment_hash) if not util.constant_time_compare(payment_secret_from_onion, expected_payment_secret): @@ -3074,7 +3078,8 @@ class Peer(Logger, EventListener): return OnionFailureCode.MPP_TIMEOUT, None, None if mpp_set.resolution == RecvMPPResolution.WAITING: - # calculate the sum of just in time channel opening fees + # calculate the sum of just in time channel opening fees, note jit only supports + # single part payments for now, this is enforced by checking against the invoice features htlc_channels = [self.lnworker.channels[channel_id] for channel_id in set(h.channel_id for h in mpp_set.htlcs)] jit_opening_fees_msat = sum((c.jit_opening_fee or 0) for c in htlc_channels) @@ -3095,15 +3100,21 @@ class Peer(Logger, EventListener): trampoline_payment_key = (payment_hash + trampoline_payment_secret).hex() if trampoline_payment_key and trampoline_payment_key != payment_key: + if jit_opening_fees_msat: + # for jit openings we only accept a single htlc + expected_amount_first_stage = any_trampoline_onion.total_msat - jit_opening_fees_msat + else: + expected_amount_first_stage = any_trampoline_onion.amt_to_forward + # first stage of trampoline payment, the first stage must never get set COMPLETE - if amount_msat >= (any_trampoline_onion.amt_to_forward - jit_opening_fees_msat): + if amount_msat >= expected_amount_first_stage: # setting the parent key will mark the htlcs to be moved to the parent set self.logger.debug(f"trampoline part complete. {len(mpp_set.htlcs)=}, " f"{amount_msat=}. setting parent key: {trampoline_payment_key}") self.lnworker.received_mpp_htlcs[payment_key] = mpp_set._replace( parent_set_key=trampoline_payment_key, ) - elif amount_msat >= (total_msat - jit_opening_fees_msat): + elif amount_msat >= (total_msat - jit_opening_fees_msat): # regular mpp or 2nd stage trampoline # set mpp_set as completed as we have received the full total_msat mpp_set = self.lnworker.set_mpp_resolution( payment_key=payment_key, @@ -3127,6 +3138,11 @@ class Peer(Logger, EventListener): if payment_info is None: _log_fail_reason(f"payment info has been deleted") return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None + elif not payment_info.invoice_features.supports(LnFeatures.BASIC_MPP_OPT) and len(mpp_set.htlcs) > 1: + # in _check_unfulfilled_htlc we already check amount == total_amount, however someone could + # send us multiple htlcs that all pay the full amount, so we also check the htlc count + _log_fail_reason(f"got mpp but we requested no mpp in the invoice: {len(mpp_set.htlcs)=}") + return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None # check invoice expiry, fail set if the invoice has expired before it was completed if mpp_set.resolution == RecvMPPResolution.WAITING: diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 68e4d129e..07662f1d5 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -1452,6 +1452,7 @@ class TestPeerDirect(TestPeer): async def run_test(test_trampoline: bool): alice_channel, bob_channel = create_test_channels() alice_peer, bob_peer, alice_wallet, bob_wallet, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + bob_wallet.features |= LnFeatures.BASIC_MPP_OPT lnaddr1, pay_req1 = self.prepare_invoice(bob_wallet, amount_msat=10_000) if test_trampoline: