diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 1ea9ae0ed..9d9ddc699 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -1,4 +1,5 @@ import asyncio +import dataclasses import shutil import copy import tempfile @@ -11,6 +12,7 @@ import concurrent from concurrent import futures from unittest import mock from typing import Iterable, NamedTuple, Tuple, List, Dict, Sequence +from types import MappingProxyType import time from aiorpcx import timeout_after, TaskTimeout @@ -39,7 +41,7 @@ from electrum.lnmsg import encode_msg, decode_msg from electrum import lnmsg from electrum.logging import console_stderr_handler, Logger from electrum.lnworker import PaymentInfo, RECEIVED -from electrum.lnonion import OnionFailureCode, OnionRoutingFailure +from electrum.lnonion import OnionFailureCode, OnionRoutingFailure, OnionHopsDataSingle, OnionPacket from electrum.lnutil import LOCAL, REMOTE, UpdateAddHtlc, RecvMPPResolution from electrum.invoices import PR_PAID, PR_UNPAID, Invoice, LN_EXPIRY_NEVER from electrum.interface import GracefulDisconnect @@ -2481,6 +2483,78 @@ class TestPeerForwarding(TestPeer): attempts=30, # the default used in LNWallet.pay_invoice() ) + async def test_forwarder_fails_for_inconsistent_trampoline_onions(self): + """ + verify that the receiver of a trampoline forwarding fails the mpp set + if the trampoline onions are not similar + In this test alice tries to forward through bob, however in one trampoline onion she sends + amt_to_forward is off by one msat. Bob should compare the trampoline onions and fail the set. + """ + + # store a modified trampoline onion to be injected into lnworker.new_onion_packet later when sending the htlcs + modified_trampoline_onion = None + def modified_new_onion_packet_trampoline(payment_path_pubkeys, session_key, hops_data: List[OnionHopsDataSingle], **kwargs): + nonlocal modified_trampoline_onion + assert modified_trampoline_onion is None, "this mock should get called only once" + modified_hops_data = copy.copy(hops_data) + # first payload (i[0]) is for bob who is supposed to forward the trampoline payment, in this + # test he should fail the incoming htlcs as their trampolines are not similar + new_payload = dict(modified_hops_data[0].payload) + amt_to_forward = dict(new_payload['amt_to_forward']) + amt_to_forward['amt_to_forward'] -= 1 + new_payload['amt_to_forward'] = amt_to_forward + modified_hops_data[0] = dataclasses.replace(modified_hops_data[0], payload=new_payload) + self.logger.debug(f"{modified_hops_data=}\nsent_{hops_data=}") + modified_trampoline_onion = electrum.lnonion.new_onion_packet( + payment_path_pubkeys, + session_key, + modified_hops_data, + **kwargs + ) + # return the unmodified onion + return electrum.lnonion.new_onion_packet( + payment_path_pubkeys, + session_key, + hops_data, + **kwargs + ) + + # this gets called in lnworker per sent htlc, for one sent htlc we inject the modified trampoline + # onion created before in the mock above + def modified_new_onion_packet_lnworker(payment_path_pubkeys, session_key, hops_data: List[OnionHopsDataSingle], **kwargs): + nonlocal modified_trampoline_onion + hops_data = copy.copy(hops_data) + if modified_trampoline_onion: + assert isinstance(modified_trampoline_onion, OnionPacket) + assert len(hops_data) == 1 + new_payload = dict(hops_data[0].payload) + new_payload['trampoline_onion_packet'] = { + "version": modified_trampoline_onion.version, + "public_key": modified_trampoline_onion.public_key, + "hops_data": modified_trampoline_onion.hops_data, + "hmac": modified_trampoline_onion.hmac, + } + hops_data[0] = dataclasses.replace(hops_data[0], payload=MappingProxyType(new_payload)) + modified_trampoline_onion = None + return electrum.lnonion.new_onion_packet( + payment_path_pubkeys, + session_key, + hops_data, + **kwargs + ) + + graph = self.create_square_graph(direct=False, test_mpp_consolidation=True, is_legacy=True) + alice = graph.workers['alice'] + alice.config.INITIAL_TRAMPOLINE_FEE_LEVEL = 6 # set high so the first attempt would succeed + with self.assertRaises(PaymentFailure): + with mock.patch('electrum.trampoline.new_onion_packet', side_effect=modified_new_onion_packet_trampoline), \ + mock.patch('electrum.lnworker.new_onion_packet', side_effect=modified_new_onion_packet_lnworker): + await self._run_trampoline_payment(graph, attempts=1) + bob_alice_channel = graph.channels[('bob', 'alice')] + bob_hm = bob_alice_channel.hm + assert len(bob_hm.all_htlcs_ever()) == 2 + assert all(bob_hm.was_htlc_failed(htlc_id=htlc.htlc_id, htlc_proposer=HTLCOwner.REMOTE) for (_, htlc) in bob_hm.all_htlcs_ever()) + class TestPeerDirectAnchors(TestPeerDirect): TEST_ANCHOR_CHANNELS = True