From 91b98240dceac7dd49be861131bb37aa09117b2e Mon Sep 17 00:00:00 2001 From: SomberNight Date: Thu, 18 Dec 2025 19:24:47 +0000 Subject: [PATCH] tests: lnpeer: follow-up prev: rm horrible ugly hack --- tests/test_lnchannel.py | 35 ++++++++++++++++++++++++------ tests/test_lnpeer.py | 48 +++++++++++++++++++++-------------------- 2 files changed, 54 insertions(+), 29 deletions(-) diff --git a/tests/test_lnchannel.py b/tests/test_lnchannel.py index b55281ca6..8aebab7c6 100644 --- a/tests/test_lnchannel.py +++ b/tests/test_lnchannel.py @@ -31,6 +31,7 @@ from pprint import pformat import logging import dataclasses import time +from typing import TYPE_CHECKING from electrum import bitcoin from electrum import lnpeer @@ -47,6 +48,10 @@ from electrum.coinchooser import PRNG from . import ElectrumTestCase +if TYPE_CHECKING: + from .test_lnpeer import MockLNWallet + + one_bitcoin_in_msat = bitcoin.COIN * 1000 @@ -134,15 +139,33 @@ def bip32(sequence): return k -def create_test_channels(*, feerate=6000, local_msat=None, remote_msat=None, - alice_name="alice", bob_name="bob", - alice_pubkey=b"\x01"*33, bob_pubkey=b"\x02"*33, random_seed=None, - anchor_outputs=False, - local_max_inflight=None, remote_max_inflight=None, - max_accepted_htlcs=5) -> tuple[Channel, Channel]: +def create_test_channels( + *, + alice_lnwallet: 'MockLNWallet' = None, + bob_lnwallet: 'MockLNWallet' = None, + feerate=6000, + local_msat=None, + remote_msat=None, + random_seed=None, + anchor_outputs=False, + local_max_inflight=None, + remote_max_inflight=None, + max_accepted_htlcs=5, +) -> tuple[Channel, Channel]: if random_seed is None: # needed for deterministic randomness random_seed = os.urandom(32) random_gen = PRNG(random_seed) + if alice_lnwallet or bob_lnwallet: + assert alice_lnwallet and bob_lnwallet, "either both or neither lnwallet must be set" + alice_name = alice_lnwallet.name + bob_name = bob_lnwallet.name + alice_pubkey = alice_lnwallet.node_keypair.pubkey + bob_pubkey = bob_lnwallet.node_keypair.pubkey + else: + alice_name = "alice" + bob_name = "bob", + alice_pubkey = b"\x01" * 33 + bob_pubkey = b"\x02" * 33 funding_txid = binascii.hexlify(random_gen.get_bytes(32)).decode("ascii") funding_index = 0 funding_sat = ((local_msat + remote_msat) // 1000) if local_msat is not None and remote_msat is not None else (bitcoin.COIN * 10) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 5e2c7df72..d2d03b710 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -534,18 +534,23 @@ class TestPeer(ElectrumTestCase): raise OnionRoutingFailure(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'') w2.register_hold_invoice(payment_hash, cb) + def prepare_lnwallets(self, graph_definition) -> Mapping[str, MockLNWallet]: + workers = {} # type: Dict[str, MockLNWallet] + for a, definition in graph_definition.items(): + workers[a] = MockLNWallet(name=a, has_anchors=self.TEST_ANCHOR_CHANNELS) + self._lnworkers_created.extend(list(workers.values())) + return workers + def prepare_chans_and_peers_in_graph( self, graph_definition, *, + workers: Dict[str, MockLNWallet] = None, channels: Mapping[Tuple[str, str], Channel] = None, ) -> Graph: - workers = {} # type: Dict[str, MockLNWallet] - # create workers - for a, definition in graph_definition.items(): - workers[a] = MockLNWallet(name=a, has_anchors=self.TEST_ANCHOR_CHANNELS) - self._lnworkers_created.extend(list(workers.values())) + if workers is None: + workers = self.prepare_lnwallets(graph_definition=graph_definition) keys = {name: w.node_keypair for name, w in workers.items()} if channels is None: @@ -562,10 +567,8 @@ class TestPeer(ElectrumTestCase): channel_ba = channels[(b, a)] else: # create new chans now channel_ab, channel_ba = create_test_channels( - alice_name=a, - bob_name=b, - alice_pubkey=keys[a].pubkey, - bob_pubkey=keys[b].pubkey, + alice_lnwallet=workers[a], + bob_lnwallet=workers[b], local_msat=channel_def['local_balance_msat'], remote_msat=channel_def['remote_balance_msat'], anchor_outputs=self.TEST_ANCHOR_CHANNELS @@ -668,16 +671,8 @@ class TestPeerDirect(TestPeer): self.GRAPH_DEFINITIONS['single_chan'], channels={('alice', 'bob'): alice_channel, ('bob', 'alice'): bob_channel}, ) - c1, c2 = graph.channels.values() p1, p2 = graph.peers.values() w1, w2 = graph.workers.values() - - # FIXME xxxxx horrible ugly hack: - c1.node_id = w2.node_keypair.pubkey - c2.node_id = w1.node_keypair.pubkey - c1.storage['node_id'] = c1.node_id - c2.storage['node_id'] = c2.node_id - return p1, p2, w1, w2 async def test_reestablish(self): @@ -701,8 +696,9 @@ class TestPeerDirect(TestPeer): async def test_reestablish_with_old_state(self): async def f(alice_slow: bool, bob_slow: bool): random_seed = os.urandom(32) - alice_channel, bob_channel = create_test_channels(random_seed=random_seed) - alice_channel_0, bob_channel_0 = create_test_channels(random_seed=random_seed) # these are identical + alice_lnwallet, bob_lnwallet = self.prepare_lnwallets(self.GRAPH_DEFINITIONS['single_chan']).values() + alice_channel, bob_channel = create_test_channels(random_seed=random_seed, alice_lnwallet=alice_lnwallet, bob_lnwallet=bob_lnwallet) + alice_channel_0, bob_channel_0 = create_test_channels(random_seed=random_seed, alice_lnwallet=alice_lnwallet, bob_lnwallet=bob_lnwallet) # these are identical p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) lnaddr, pay_req = self.prepare_invoice(w2) async def pay(): @@ -776,7 +772,8 @@ class TestPeerDirect(TestPeer): ----add--> ----sig--> """ - chan_AB, chan_BA = create_test_channels() + alice_lnwallet, bob_lnwallet = self.prepare_lnwallets(self.GRAPH_DEFINITIONS['single_chan']).values() + chan_AB, chan_BA = create_test_channels(alice_lnwallet=alice_lnwallet, bob_lnwallet=bob_lnwallet) # note: we don't start peer.htlc_switch() so that the fake htlcs are left alone. async def f(): p1, p2, w1, w2 = self.prepare_peers(chan_AB, chan_BA) @@ -831,7 +828,8 @@ class TestPeerDirect(TestPeer): ----sig--> ----rev--> """ - chan_AB, chan_BA = create_test_channels() + alice_lnwallet, bob_lnwallet = self.prepare_lnwallets(self.GRAPH_DEFINITIONS['single_chan']).values() + chan_AB, chan_BA = create_test_channels(alice_lnwallet=alice_lnwallet, bob_lnwallet=bob_lnwallet) # note: we don't start peer.htlc_switch() so that the fake htlcs are left alone. async def f(): p1, p2, w1, w2 = self.prepare_peers(chan_AB, chan_BA) @@ -1703,7 +1701,8 @@ class TestPeerDirect(TestPeer): await gath async def test_close_upfront_shutdown_script(self): - alice_channel, bob_channel = create_test_channels() + alice_lnwallet, bob_lnwallet = self.prepare_lnwallets(self.GRAPH_DEFINITIONS['single_chan']).values() + alice_channel, bob_channel = create_test_channels(alice_lnwallet=alice_lnwallet, bob_lnwallet=bob_lnwallet) # create upfront shutdown script for bob, alice doesn't use upfront # shutdown script @@ -1995,7 +1994,10 @@ class TestPeerDirect(TestPeer): $ py-spy record -o flamegraph.svg --subprocesses -- python -m pytest tests/test_lnpeer.py::TestPeerDirect::test_htlc_switch_iteration_benchmark """ NUM_ITERATIONS = 25 - alice_channel, bob_channel = create_test_channels(max_accepted_htlcs=20) + alice_lnwallet, bob_lnwallet = self.prepare_lnwallets(self.GRAPH_DEFINITIONS['single_chan']).values() + alice_channel, bob_channel = create_test_channels( + alice_lnwallet=alice_lnwallet, bob_lnwallet=bob_lnwallet, max_accepted_htlcs=20, + ) alice_p, bob_p, alice_w, bob_w = self.prepare_peers(alice_channel, bob_channel) await self._activate_trampoline(alice_w)