diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index faa32a9d3..72db12657 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -2082,15 +2082,17 @@ class Peer(Logger, EventListener): chan.receive_htlc(htlc, onion_packet) util.trigger_callback('htlc_added', chan, htlc, RECEIVED) - def check_accepted_htlc( - self, *, - chan: Channel, + @staticmethod + def _check_accepted_final_htlc( + *, chan: Channel, htlc: UpdateAddHtlc, processed_onion: ProcessedOnionPacket, + is_trampoline_onion: bool = False, log_fail_reason: Callable[[str], None], ) -> tuple[bytes, int, int, OnionRoutingFailure]: """ - Perform checks that are invariant (results do not depend on height, network conditions, etc). + Perform checks that are invariant (results do not depend on height, network conditions, etc.) + for htlcs of which we are the receiver (forwarding htlcs will have their checks in maybe_forward_htlc). May raise OnionRoutingFailure """ assert processed_onion.are_we_final, processed_onion @@ -2120,11 +2122,13 @@ class Peer(Logger, EventListener): else: channel_opening_fee = 0 - if amt_to_forward > htlc.amount_msat: - log_fail_reason(f"amt_to_forward != htlc.amount_msat") - raise OnionRoutingFailure( - code=OnionFailureCode.FINAL_INCORRECT_HTLC_AMOUNT, - data=htlc.amount_msat.to_bytes(8, byteorder="big")) + if not is_trampoline_onion: + # for inner trampoline onions amt_to_forward can be larger than the htlc amount + if amt_to_forward > htlc.amount_msat: + log_fail_reason(f"{amt_to_forward=} > {htlc.amount_msat=}") + raise OnionRoutingFailure( + code=OnionFailureCode.FINAL_INCORRECT_HTLC_AMOUNT, + data=htlc.amount_msat.to_bytes(8, byteorder="big")) if (payment_secret_from_onion := processed_onion.payment_secret) is None: log_fail_reason(f"'payment_secret' missing from onion") @@ -2166,6 +2170,7 @@ class Peer(Logger, EventListener): chan: Channel, htlc: UpdateAddHtlc, processed_onion: ProcessedOnionPacket, + outer_onion_payment_secret: bytes = None, # used to group trampoline htlcs for forwarding onion_packet_bytes: bytes, already_forwarded: bool = False, ) -> Tuple[Optional[bytes], Optional[Tuple[str, Callable[[], Awaitable[Optional[str]]]]]]: @@ -2199,10 +2204,11 @@ class Peer(Logger, EventListener): local_height = chain.height() # parse parameters and perform checks that are invariant - payment_secret_from_onion, total_msat, channel_opening_fee, exc_incorrect_or_unknown_pd = self.check_accepted_htlc( + payment_secret_from_onion, total_msat, channel_opening_fee, exc_incorrect_or_unknown_pd = self._check_accepted_final_htlc( chan=chan, htlc=htlc, processed_onion=processed_onion, + is_trampoline_onion=bool(outer_onion_payment_secret), log_fail_reason=log_fail_reason) # payment key for final onions @@ -2244,6 +2250,7 @@ class Peer(Logger, EventListener): chan=chan, htlc=htlc, processed_onion=trampoline_onion, + outer_onion_payment_secret=payment_secret_from_onion, onion_packet_bytes=onion_packet_bytes, already_forwarded=already_forwarded, ) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index be70b08b1..360ede67e 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -2022,6 +2022,9 @@ class TestPeerForwarding(TestPeer): async def pay(lnaddr, pay_req): self.assertEqual(PR_UNPAID, dest_w.get_payment_status(lnaddr.paymenthash)) result, log = await sender_w.pay_invoice(pay_req, attempts=attempts) + async with OldTaskGroup() as g: + for peer in peers: + await g.spawn(peer.wait_one_htlc_switch_iteration()) async with OldTaskGroup() as g: for peer in peers: await g.spawn(peer.wait_one_htlc_switch_iteration()) @@ -2122,6 +2125,33 @@ class TestPeerForwarding(TestPeer): await self._run_trampoline_payment( graph, sender_name='alice', destination_name='edward',tf_names=('bob', 'dave')) + async def test_multi_trampoline_payment(self): + """ + Alice splits her payment to Dave between two trampoline forwarding nodes Carol and Bob. + This should test Multi-Trampoline MPP: + https://github.com/lightning/bolts/blob/bc7a1a0bc97b2293e7f43dd8a06529e5fdcf7cd2/proposals/trampoline.md#multi-trampoline-mpp + """ + graph_definition = self.GRAPH_DEFINITIONS['square_graph'] + # payment amount is 100_000_000 msat, size the channels so that alice must use both to succeed + graph_definition['alice']['channels']['bob']['local_balance_msat'] = int(100_000_000 * 0.75) + graph_definition['alice']['channels']['carol']['local_balance_msat'] = int(100_000_000 * 0.75) + g = self.prepare_chans_and_peers_in_graph(graph_definition) + w = g.workers['alice'], g.workers['carol'], g.workers['bob'], g.workers['dave'] + alice_w, carol_w, bob_w, dave_w = w + + alice_w.config.TEST_FORCE_MPP = True + bob_w.config.TEST_FORCE_MPP = True + carol_w.config.TEST_FORCE_MPP = True + dave_w.features |= LnFeatures.BASIC_MPP_OPT + + with self.assertRaises(PaymentDone): + await self._run_trampoline_payment( + g, + sender_name='alice', + destination_name='dave', + tf_names=('bob', 'carol'), + attempts=30, # the default used in LNWallet.pay_invoice() + ) class TestPeerDirectAnchors(TestPeerDirect): TEST_ANCHOR_CHANNELS = True