1
0

lnpeer: fail htlcs if we get unwanted mpp

Fail incoming htlcs if we receive a payment consisting of multiple parts
if we signaled to not want mpp in the invoice.
This commit is contained in:
f321x
2025-12-08 14:18:05 +01:00
parent 125a921cc4
commit 183d426e93
2 changed files with 20 additions and 3 deletions

View File

@@ -2240,6 +2240,10 @@ class Peer(Logger, EventListener):
elif htlc.timestamp > info.expiration_ts: # the set will get failed too if now > exp_ts
_log_fail_reason(f"not accepting htlc for expired invoice")
raise exc_incorrect_or_unknown_pd
elif not info.invoice_features.supports(LnFeatures.BASIC_MPP_OPT) and total_msat > htlc.amount_msat:
# in _check_unfulfilled_htlc_set we check the count to prevent mpp through overpayment
_log_fail_reason(f"got mpp but we requested no mpp in the invoice: {total_msat=} > {htlc.amount_msat=}")
raise exc_incorrect_or_unknown_pd
expected_payment_secret = self.lnworker.get_payment_secret(payment_hash)
if not util.constant_time_compare(payment_secret_from_onion, expected_payment_secret):
@@ -3074,7 +3078,8 @@ class Peer(Logger, EventListener):
return OnionFailureCode.MPP_TIMEOUT, None, None
if mpp_set.resolution == RecvMPPResolution.WAITING:
# calculate the sum of just in time channel opening fees
# calculate the sum of just in time channel opening fees, note jit only supports
# single part payments for now, this is enforced by checking against the invoice features
htlc_channels = [self.lnworker.channels[channel_id] for channel_id in set(h.channel_id for h in mpp_set.htlcs)]
jit_opening_fees_msat = sum((c.jit_opening_fee or 0) for c in htlc_channels)
@@ -3095,15 +3100,21 @@ class Peer(Logger, EventListener):
trampoline_payment_key = (payment_hash + trampoline_payment_secret).hex()
if trampoline_payment_key and trampoline_payment_key != payment_key:
if jit_opening_fees_msat:
# for jit openings we only accept a single htlc
expected_amount_first_stage = any_trampoline_onion.total_msat - jit_opening_fees_msat
else:
expected_amount_first_stage = any_trampoline_onion.amt_to_forward
# first stage of trampoline payment, the first stage must never get set COMPLETE
if amount_msat >= (any_trampoline_onion.amt_to_forward - jit_opening_fees_msat):
if amount_msat >= expected_amount_first_stage:
# setting the parent key will mark the htlcs to be moved to the parent set
self.logger.debug(f"trampoline part complete. {len(mpp_set.htlcs)=}, "
f"{amount_msat=}. setting parent key: {trampoline_payment_key}")
self.lnworker.received_mpp_htlcs[payment_key] = mpp_set._replace(
parent_set_key=trampoline_payment_key,
)
elif amount_msat >= (total_msat - jit_opening_fees_msat):
elif amount_msat >= (total_msat - jit_opening_fees_msat): # regular mpp or 2nd stage trampoline
# set mpp_set as completed as we have received the full total_msat
mpp_set = self.lnworker.set_mpp_resolution(
payment_key=payment_key,
@@ -3127,6 +3138,11 @@ class Peer(Logger, EventListener):
if payment_info is None:
_log_fail_reason(f"payment info has been deleted")
return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None
elif not payment_info.invoice_features.supports(LnFeatures.BASIC_MPP_OPT) and len(mpp_set.htlcs) > 1:
# in _check_unfulfilled_htlc we already check amount == total_amount, however someone could
# send us multiple htlcs that all pay the full amount, so we also check the htlc count
_log_fail_reason(f"got mpp but we requested no mpp in the invoice: {len(mpp_set.htlcs)=}")
return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None
# check invoice expiry, fail set if the invoice has expired before it was completed
if mpp_set.resolution == RecvMPPResolution.WAITING:

View File

@@ -1452,6 +1452,7 @@ class TestPeerDirect(TestPeer):
async def run_test(test_trampoline: bool):
alice_channel, bob_channel = create_test_channels()
alice_peer, bob_peer, alice_wallet, bob_wallet, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
bob_wallet.features |= LnFeatures.BASIC_MPP_OPT
lnaddr1, pay_req1 = self.prepare_invoice(bob_wallet, amount_msat=10_000)
if test_trampoline: