diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 84bbbb616..2ad720332 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -435,7 +435,42 @@ _GRAPH_DEFINITIONS = { }, 'dave': { }, - } + }, + 'line_graph': { + 'alice': { + 'channels': { + 'bob': low_fee_channel.copy(), + }, + }, + 'bob': { # Trampoline Forwarder + 'channels': { + 'carol': low_fee_channel.copy(), + }, + 'config': { + SimpleConfig.EXPERIMENTAL_LN_FORWARD_PAYMENTS: True, + SimpleConfig.EXPERIMENTAL_LN_FORWARD_TRAMPOLINE_PAYMENTS: True, + }, + }, + 'carol': { + 'channels': { + 'dave': low_fee_channel.copy(), + }, + 'config': { + SimpleConfig.EXPERIMENTAL_LN_FORWARD_PAYMENTS: True, + }, + }, + 'dave': { # Trampoline Forwarder + 'channels': { + 'edward': low_fee_channel.copy(), + }, + 'config': { + SimpleConfig.EXPERIMENTAL_LN_FORWARD_PAYMENTS: True, + SimpleConfig.EXPERIMENTAL_LN_FORWARD_TRAMPOLINE_PAYMENTS: True, + }, + }, + 'edward': { + }, + }, } @@ -450,6 +485,20 @@ class PaymentTimeout(Exception): pass class SuccessfulTest(Exception): pass +def inject_chan_into_gossipdb(*, channel_db: ChannelDB, graph: Graph, node1name: str, node2name: str) -> None: + chan_ann_raw = graph.channels[(node1name, node2name)].construct_channel_announcement_without_sigs()[0] + chan_ann_dict = decode_msg(chan_ann_raw)[1] + channel_db.add_channel_announcements(chan_ann_dict, trusted=True) + + chan_upd1_raw = graph.channels[(node1name, node2name)].get_outgoing_gossip_channel_update() + chan_upd1_dict = decode_msg(chan_upd1_raw)[1] + channel_db.add_channel_update(chan_upd1_dict, verify=False) + + chan_upd2_raw = graph.channels[(node2name, node1name)].get_outgoing_gossip_channel_update() + chan_upd2_dict = decode_msg(chan_upd2_raw)[1] + channel_db.add_channel_update(chan_upd2_dict, verify=False) + + class TestPeer(ElectrumTestCase): TESTNET = True @@ -1497,9 +1546,6 @@ class TestPeerForwarding(TestPeer): print(f" {keys[a].pubkey.hex()}") return graph - async def test_payment_multihop(self): - graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['square_graph']) - async def test_payment_multihop(self): graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['square_graph']) peers = graph.peers.values() @@ -1890,48 +1936,52 @@ class TestPeerForwarding(TestPeer): await f() async def _run_trampoline_payment( - self, graph, *, + self, graph: Graph, *, include_routing_hints=True, test_hold_invoice=False, test_failure=False, - attempts=2): + attempts=2, + sender_name="alice", + destination_name="dave", + tf_names=("bob", "carol"), + ): - alice_w = graph.workers['alice'] - dave_w = graph.workers['dave'] + sender_w = graph.workers[sender_name] + dest_w = graph.workers[destination_name] async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, dave_w.get_payment_status(lnaddr.paymenthash)) - result, log = await alice_w.pay_invoice(pay_req, attempts=attempts) + self.assertEqual(PR_UNPAID, dest_w.get_payment_status(lnaddr.paymenthash)) + result, log = await sender_w.pay_invoice(pay_req, attempts=attempts) async with OldTaskGroup() as g: for peer in peers: await g.spawn(peer.wait_one_htlc_switch_iteration()) for peer in peers: self.assertEqual(len(peer.lnworker.active_forwardings), 0) if result: - self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, dest_w.get_payment_status(lnaddr.paymenthash)) raise PaymentDone() else: raise NoPathFound() async def f(): - await self._activate_trampoline(alice_w) + await self._activate_trampoline(sender_w) async with OldTaskGroup() as group: for peer in peers: await group.spawn(peer._message_loop()) await group.spawn(peer.htlc_switch()) for peer in peers: await peer.initialized - lnaddr, pay_req = self.prepare_invoice(dave_w, include_routing_hints=include_routing_hints) - self.prepare_recipient(dave_w, lnaddr.paymenthash, test_hold_invoice, test_failure) + lnaddr, pay_req = self.prepare_invoice(dest_w, include_routing_hints=include_routing_hints) + self.prepare_recipient(dest_w, lnaddr.paymenthash, test_hold_invoice, test_failure) await group.spawn(pay(lnaddr, pay_req)) peers = graph.peers.values() # declare routing nodes as trampoline nodes - electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { - graph.workers['bob'].name: LNPeerAddr(host="127.0.0.1", port=9735, pubkey=graph.workers['bob'].node_keypair.pubkey), - graph.workers['carol'].name: LNPeerAddr(host="127.0.0.1", port=9735, pubkey=graph.workers['carol'].node_keypair.pubkey), - } + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {} + for tf_name in tf_names: + peer_addr = LNPeerAddr(host="127.0.0.1", port=9735, pubkey=graph.workers[tf_name].node_keypair.pubkey) + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS[graph.workers[tf_name].name] = peer_addr await f() @@ -1973,6 +2023,7 @@ class TestPeerForwarding(TestPeer): await self._run_trampoline_payment(graph, test_hold_invoice=True, test_failure=True) async def test_payment_trampoline_legacy(self): + # alice -> T1_bob -> carol -> dave with self.assertRaises(PaymentDone): graph = self.create_square_graph(direct=False, is_legacy=True) await self._run_trampoline_payment(graph, include_routing_hints=True) @@ -1980,17 +2031,28 @@ class TestPeerForwarding(TestPeer): graph = self.create_square_graph(direct=False, is_legacy=True) await self._run_trampoline_payment(graph, include_routing_hints=False) - async def test_payment_trampoline_e2e_direct(self): + async def test_payment_trampoline_e2e_alice_t1_dave(self): with self.assertRaises(PaymentDone): graph = self.create_square_graph(direct=True, is_legacy=False) await self._run_trampoline_payment(graph) - async def test_payment_trampoline_e2e_indirect(self): - # must use two trampolines + async def test_payment_trampoline_e2e_alice_t1_t2_dave(self): with self.assertRaises(PaymentDone): graph = self.create_square_graph(direct=False, is_legacy=False) await self._run_trampoline_payment(graph) + async def test_payment_trampoline_e2e_alice_t1_carol_t2_edward(self): + # alice -> T1_bob -> carol -> T2_dave -> edward + graph_definition = self.GRAPH_DEFINITIONS['line_graph'] + graph = self.prepare_chans_and_peers_in_graph(graph_definition) + inject_chan_into_gossipdb( + channel_db=graph.workers['bob'].channel_db, graph=graph, + node1name='carol', node2name='dave') + with self.assertRaises(PaymentDone): + await self._run_trampoline_payment( + graph, sender_name='alice', destination_name='edward',tf_names=('bob', 'dave')) + + class TestPeerDirectAnchors(TestPeerDirect): TEST_ANCHOR_CHANNELS = True