lnpeer: fail htlcs if we get unwanted mpp
Fail incoming htlcs if we receive a payment consisting of multiple parts if we signaled to not want mpp in the invoice.
This commit is contained in:
@@ -2240,6 +2240,10 @@ class Peer(Logger, EventListener):
|
||||
elif htlc.timestamp > info.expiration_ts: # the set will get failed too if now > exp_ts
|
||||
_log_fail_reason(f"not accepting htlc for expired invoice")
|
||||
raise exc_incorrect_or_unknown_pd
|
||||
elif not info.invoice_features.supports(LnFeatures.BASIC_MPP_OPT) and total_msat > htlc.amount_msat:
|
||||
# in _check_unfulfilled_htlc_set we check the count to prevent mpp through overpayment
|
||||
_log_fail_reason(f"got mpp but we requested no mpp in the invoice: {total_msat=} > {htlc.amount_msat=}")
|
||||
raise exc_incorrect_or_unknown_pd
|
||||
|
||||
expected_payment_secret = self.lnworker.get_payment_secret(payment_hash)
|
||||
if not util.constant_time_compare(payment_secret_from_onion, expected_payment_secret):
|
||||
@@ -3074,7 +3078,8 @@ class Peer(Logger, EventListener):
|
||||
return OnionFailureCode.MPP_TIMEOUT, None, None
|
||||
|
||||
if mpp_set.resolution == RecvMPPResolution.WAITING:
|
||||
# calculate the sum of just in time channel opening fees
|
||||
# calculate the sum of just in time channel opening fees, note jit only supports
|
||||
# single part payments for now, this is enforced by checking against the invoice features
|
||||
htlc_channels = [self.lnworker.channels[channel_id] for channel_id in set(h.channel_id for h in mpp_set.htlcs)]
|
||||
jit_opening_fees_msat = sum((c.jit_opening_fee or 0) for c in htlc_channels)
|
||||
|
||||
@@ -3095,15 +3100,21 @@ class Peer(Logger, EventListener):
|
||||
trampoline_payment_key = (payment_hash + trampoline_payment_secret).hex()
|
||||
|
||||
if trampoline_payment_key and trampoline_payment_key != payment_key:
|
||||
if jit_opening_fees_msat:
|
||||
# for jit openings we only accept a single htlc
|
||||
expected_amount_first_stage = any_trampoline_onion.total_msat - jit_opening_fees_msat
|
||||
else:
|
||||
expected_amount_first_stage = any_trampoline_onion.amt_to_forward
|
||||
|
||||
# first stage of trampoline payment, the first stage must never get set COMPLETE
|
||||
if amount_msat >= (any_trampoline_onion.amt_to_forward - jit_opening_fees_msat):
|
||||
if amount_msat >= expected_amount_first_stage:
|
||||
# setting the parent key will mark the htlcs to be moved to the parent set
|
||||
self.logger.debug(f"trampoline part complete. {len(mpp_set.htlcs)=}, "
|
||||
f"{amount_msat=}. setting parent key: {trampoline_payment_key}")
|
||||
self.lnworker.received_mpp_htlcs[payment_key] = mpp_set._replace(
|
||||
parent_set_key=trampoline_payment_key,
|
||||
)
|
||||
elif amount_msat >= (total_msat - jit_opening_fees_msat):
|
||||
elif amount_msat >= (total_msat - jit_opening_fees_msat): # regular mpp or 2nd stage trampoline
|
||||
# set mpp_set as completed as we have received the full total_msat
|
||||
mpp_set = self.lnworker.set_mpp_resolution(
|
||||
payment_key=payment_key,
|
||||
@@ -3127,6 +3138,11 @@ class Peer(Logger, EventListener):
|
||||
if payment_info is None:
|
||||
_log_fail_reason(f"payment info has been deleted")
|
||||
return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None
|
||||
elif not payment_info.invoice_features.supports(LnFeatures.BASIC_MPP_OPT) and len(mpp_set.htlcs) > 1:
|
||||
# in _check_unfulfilled_htlc we already check amount == total_amount, however someone could
|
||||
# send us multiple htlcs that all pay the full amount, so we also check the htlc count
|
||||
_log_fail_reason(f"got mpp but we requested no mpp in the invoice: {len(mpp_set.htlcs)=}")
|
||||
return OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, None, None
|
||||
|
||||
# check invoice expiry, fail set if the invoice has expired before it was completed
|
||||
if mpp_set.resolution == RecvMPPResolution.WAITING:
|
||||
|
||||
@@ -1452,6 +1452,7 @@ class TestPeerDirect(TestPeer):
|
||||
async def run_test(test_trampoline: bool):
|
||||
alice_channel, bob_channel = create_test_channels()
|
||||
alice_peer, bob_peer, alice_wallet, bob_wallet, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||
bob_wallet.features |= LnFeatures.BASIC_MPP_OPT
|
||||
lnaddr1, pay_req1 = self.prepare_invoice(bob_wallet, amount_msat=10_000)
|
||||
|
||||
if test_trampoline:
|
||||
|
||||
Reference in New Issue
Block a user