diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 99d45456f..5e2c7df72 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -12,7 +12,7 @@ import concurrent from concurrent import futures from functools import lru_cache from unittest import mock -from typing import Iterable, NamedTuple, Tuple, List, Dict, Sequence +from typing import Iterable, NamedTuple, Tuple, List, Dict, Sequence, Mapping from types import MappingProxyType import time import statistics @@ -323,6 +323,24 @@ depleted_channel = { } _GRAPH_DEFINITIONS = { + # A -- B + 'single_chan' : { + 'alice': { + 'channels': { + 'bob': { + 'local_balance_msat': 10 * bitcoin.COIN * 1000 // 2, + 'remote_balance_msat': 10 * bitcoin.COIN * 1000 // 2, + }, + }, + }, + 'bob': { + }, + }, + # A + # high fee / \ low fee + # B C + # high fee \ / low fee + # D 'square_graph': { 'alice': { 'channels': { @@ -353,6 +371,7 @@ _GRAPH_DEFINITIONS = { 'dave': { }, }, + # A -- B -- C -- D -- E 'line_graph': { 'alice': { 'channels': { @@ -515,7 +534,12 @@ class TestPeer(ElectrumTestCase): raise OnionRoutingFailure(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'') w2.register_hold_invoice(payment_hash, cb) - def prepare_chans_and_peers_in_graph(self, graph_definition) -> Graph: + def prepare_chans_and_peers_in_graph( + self, + graph_definition, + *, + channels: Mapping[Tuple[str, str], Channel] = None, + ) -> Graph: workers = {} # type: Dict[str, MockLNWallet] # create workers @@ -524,32 +548,42 @@ class TestPeer(ElectrumTestCase): self._lnworkers_created.extend(list(workers.values())) keys = {name: w.node_keypair for name, w in workers.items()} - channels = {} # type: Dict[Tuple[str, str], Channel] + if channels is None: + channels = {} # type: Dict[Tuple[str, str], Channel] transports = {} peers = {} # create channels for a, definition in graph_definition.items(): for b, channel_def in definition.get('channels', {}).items(): - channel_ab, channel_ba = create_test_channels( - alice_name=a, - bob_name=b, - alice_pubkey=keys[a].pubkey, - bob_pubkey=keys[b].pubkey, - local_msat=channel_def['local_balance_msat'], - remote_msat=channel_def['remote_balance_msat'], - anchor_outputs=self.TEST_ANCHOR_CHANNELS - ) - channels[(a, b)], channels[(b, a)] = channel_ab, channel_ba + if ((a, b) in channels) or ((b, a) in channels): + # if either chan direction is present, both must be present + channel_ab = channels[(a, b)] + channel_ba = channels[(b, a)] + else: # create new chans now + channel_ab, channel_ba = create_test_channels( + alice_name=a, + bob_name=b, + alice_pubkey=keys[a].pubkey, + bob_pubkey=keys[b].pubkey, + local_msat=channel_def['local_balance_msat'], + remote_msat=channel_def['remote_balance_msat'], + anchor_outputs=self.TEST_ANCHOR_CHANNELS + ) + channels[(a, b)], channels[(b, a)] = channel_ab, channel_ba workers[a]._add_channel(channel_ab) workers[b]._add_channel(channel_ba) transport_ab, transport_ba = transport_pair(keys[a], keys[b], channel_ab.name, channel_ba.name) transports[(a, b)], transports[(b, a)] = transport_ab, transport_ba # set fees - channel_ab.forwarding_fee_proportional_millionths = channel_def['local_fee_rate_millionths'] - channel_ab.forwarding_fee_base_msat = channel_def['local_base_fee_msat'] - channel_ba.forwarding_fee_proportional_millionths = channel_def['remote_fee_rate_millionths'] - channel_ba.forwarding_fee_base_msat = channel_def['remote_base_fee_msat'] + if 'local_fee_rate_millionths' in channel_def: + channel_ab.forwarding_fee_proportional_millionths = channel_def['local_fee_rate_millionths'] + if 'local_base_fee_msat' in channel_def: + channel_ab.forwarding_fee_base_msat = channel_def['local_base_fee_msat'] + if 'remote_fee_rate_millionths' in channel_def: + channel_ba.forwarding_fee_proportional_millionths = channel_def['remote_fee_rate_millionths'] + if 'remote_base_fee_msat' in channel_def: + channel_ba.forwarding_fee_base_msat = channel_def['remote_base_fee_msat'] # create peers for ab in channels.keys(): @@ -630,34 +664,27 @@ class TestPeerDirect(TestPeer): def prepare_peers( self, alice_channel: Channel, bob_channel: Channel, ): - w1 = MockLNWallet(name=bob_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS) - w2 = MockLNWallet(name=alice_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS) - k1 = w1.node_keypair - k2 = w2.node_keypair - alice_channel.node_id = k2.pubkey - bob_channel.node_id = k1.pubkey - alice_channel.storage['node_id'] = alice_channel.node_id - bob_channel.storage['node_id'] = bob_channel.node_id - t1, t2 = transport_pair(k1, k2, alice_channel.name, bob_channel.name) - w1._add_channel(alice_channel) - w2._add_channel(bob_channel) - self._lnworkers_created.extend([w1, w2]) - p1 = PeerInTests(w1, k2.pubkey, t1) - p2 = PeerInTests(w2, k1.pubkey, t2) - w1.lnpeermgr._peers[p1.pubkey] = p1 - w2.lnpeermgr._peers[p2.pubkey] = p2 - # mark_open won't work if state is already OPEN. - # so set it to FUNDED - alice_channel._state = ChannelState.FUNDED - bob_channel._state = ChannelState.FUNDED - # this populates the channel graph: - p1.mark_open(alice_channel) - p2.mark_open(bob_channel) + graph = self.prepare_chans_and_peers_in_graph( + self.GRAPH_DEFINITIONS['single_chan'], + channels={('alice', 'bob'): alice_channel, ('bob', 'alice'): bob_channel}, + ) + c1, c2 = graph.channels.values() + p1, p2 = graph.peers.values() + w1, w2 = graph.workers.values() + + # FIXME xxxxx horrible ugly hack: + c1.node_id = w2.node_keypair.pubkey + c2.node_id = w1.node_keypair.pubkey + c1.storage['node_id'] = c1.node_id + c2.storage['node_id'] = c2.node_id + return p1, p2, w1, w2 async def test_reestablish(self): - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + p1, p2 = graph.peers.values() + alice_channel, bob_channel = graph.channels.values() + for chan in (alice_channel, bob_channel): chan.peer_state = PeerState.DISCONNECTED async def reestablish(): @@ -852,8 +879,9 @@ class TestPeerDirect(TestPeer): test_bundle_timeout=False ): """Alice pays Bob a single HTLC via direct channel.""" - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + p1, p2 = graph.peers.values() + w1, w2 = graph.workers.values() results = {} async def pay(lnaddr, pay_req): self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) @@ -938,8 +966,9 @@ class TestPeerDirect(TestPeer): await self._test_simple_payment(test_trampoline=test_trampoline, test_hold_invoice=True, test_failure=True) async def test_check_invoice_before_payment(self): - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + p1, p2 = graph.peers.values() + w1, w2 = graph.workers.values() async def try_paying_some_invoices(): # feature bits: unknown even fbit invoice_features = w2.features.for_invoice() | (1 << 990) # add undefined even fbit @@ -975,8 +1004,9 @@ class TestPeerDirect(TestPeer): rejected immediately upon receiving them. """ async def run_test(test_trampoline): - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + p1, p2 = graph.peers.values() + w1, w2 = graph.workers.values() async def try_pay_with_too_low_final_cltv_delta(lnaddr, w1=w1, w2=w2): self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) @@ -1018,8 +1048,9 @@ class TestPeerDirect(TestPeer): async def test_reject_payment_for_expired_invoice(self): """Tests that new htlcs paying an invoice that has already been expired will get rejected.""" async def run_test(test_trampoline): - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + p1, p2 = graph.peers.values() + w1, w2 = graph.workers.values() # create lightning invoice in the past, so it is expired with mock.patch('time.time', return_value=int(time.time()) - 10000): @@ -1061,8 +1092,9 @@ class TestPeerDirect(TestPeer): async def test_reject_mpp_for_non_mpp_invoice(self): """Test that we reject a payment if it is mpp and we didn't signal support for mpp in the invoice""" async def run_test(test_trampoline): - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + p1, p2 = graph.peers.values() + w1, w2 = graph.workers.values() w1.config.TEST_FORCE_MPP = True # force alice to send mpp if test_trampoline: @@ -1101,8 +1133,9 @@ class TestPeerDirect(TestPeer): async def test_reject_multiple_payments_of_same_invoice(self): """Tests that new htlcs paying an invoice that has already been paid will get rejected.""" async def run_test(test_trampoline): - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + p1, p2 = graph.peers.values() + w1, w2 = graph.workers.values() lnaddr, _pay_req = self.prepare_invoice(w2) @@ -1145,8 +1178,10 @@ class TestPeerDirect(TestPeer): before sending 'commitment_signed'. Neither party should fulfill the respective HTLCs until those are irrevocably committed to. """ - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + p1, p2 = graph.peers.values() + w1, w2 = graph.workers.values() + alice_channel, bob_channel = graph.channels.values() async def pay(): await util.wait_for2(p1.initialized, 1) await util.wait_for2(p2.initialized, 1) @@ -1220,8 +1255,10 @@ class TestPeerDirect(TestPeer): #@unittest.skip("too expensive") async def test_payments_stresstest(self): - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + p1, p2 = graph.peers.values() + w1, w2 = graph.workers.values() + alice_channel, bob_channel = graph.channels.values() alice_init_balance_msat = alice_channel.balance(HTLCOwner.LOCAL) bob_init_balance_msat = bob_channel.balance(HTLCOwner.LOCAL) num_payments = 50 @@ -1252,8 +1289,10 @@ class TestPeerDirect(TestPeer): # - Alice sends htlc1: 0.1 BTC, H1, S1 (total_msat=1 BTC) # - Alice sends htlc2: 0.9 BTC, H2, S1 (total_msat=1 BTC) # - Bob(victim) reveals preimage for H1 and fulfills htlc1 (fails other) - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + p1, p2 = graph.peers.values() + w1, w2 = graph.workers.values() + alice_channel, bob_channel = graph.channels.values() async def pay(): self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash, direction=RECEIVED)) self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr2.paymenthash, direction=RECEIVED)) @@ -1325,8 +1364,10 @@ class TestPeerDirect(TestPeer): # - Alice sends htlc1: 0.1 BTC (total_msat=0.2 BTC) # - Alice sends htlc2: 0.1 BTC (total_msat=1 BTC) # - Bob(victim) reveals preimage and fulfills htlc2 (fails other) - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + p1, p2 = graph.peers.values() + w1, w2 = graph.workers.values() + alice_channel, bob_channel = graph.channels.values() async def pay(): self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash, direction=RECEIVED)) @@ -1392,8 +1433,10 @@ class TestPeerDirect(TestPeer): """Alice gets two htlcs as part of a mpp, one has a cltv too close to expiry and will get failed. Test that the other htlc won't get settled if the mpp isn't complete anymore after failing the other htlc. """ - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + p1, p2 = graph.peers.values() + w1, w2 = graph.workers.values() + alice_channel, bob_channel = graph.channels.values() async def pay(): await util.wait_for2(p1.initialized, 1) await util.wait_for2(p2.initialized, 1) @@ -1477,8 +1520,10 @@ class TestPeerDirect(TestPeer): and the sender gets a second chance to pay the same invoice. """ async def run_test(test_trampoline: bool): - alice_channel, bob_channel = create_test_channels() - alice_peer, bob_peer, alice_wallet, bob_wallet = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + alice_peer, bob_peer = graph.peers.values() + alice_wallet, bob_wallet = graph.workers.values() + alice_channel, bob_channel = graph.channels.values() bob_wallet.features |= LnFeatures.BASIC_MPP_OPT lnaddr1, pay_req1 = self.prepare_invoice(bob_wallet, amount_msat=10_000) @@ -1591,8 +1636,10 @@ class TestPeerDirect(TestPeer): # )) async def _test_shutdown(self, alice_fee, bob_fee, alice_fee_range=None, bob_fee_range=None): - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + p1, p2 = graph.peers.values() + w1, w2 = graph.workers.values() + alice_channel, bob_channel = graph.channels.values() w1.network.config.TEST_SHUTDOWN_FEE = alice_fee w2.network.config.TEST_SHUTDOWN_FEE = bob_fee if alice_fee_range is not None: @@ -1628,8 +1675,9 @@ class TestPeerDirect(TestPeer): await gath async def test_warning(self): - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + p1, p2 = graph.peers.values() + alice_channel, bob_channel = graph.channels.values() async def action(): await util.wait_for2(p1.initialized, 1) @@ -1640,8 +1688,9 @@ class TestPeerDirect(TestPeer): await gath async def test_error(self): - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + p1, p2 = graph.peers.values() + alice_channel, bob_channel = graph.channels.values() async def action(): await util.wait_for2(p1.initialized, 1) @@ -1730,8 +1779,10 @@ class TestPeerDirect(TestPeer): self.assertEqual(1, len(closing_tx.get_output_idxs_from_address(bob_uss_addr))) async def test_channel_usage_after_closing(self): - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + p1, p2 = graph.peers.values() + w1, w2 = graph.workers.values() + alice_channel, bob_channel = graph.channels.values() lnaddr, pay_req = self.prepare_invoice(w2) lnaddr = w1._check_bolt11_invoice(pay_req.lightning_invoice) @@ -1771,8 +1822,8 @@ class TestPeerDirect(TestPeer): await f() async def test_sending_weird_messages_that_should_be_ignored(self): - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + p1, p2 = graph.peers.values() async def send_weird_messages(): await util.wait_for2(p1.initialized, 1) @@ -1802,8 +1853,8 @@ class TestPeerDirect(TestPeer): await f() async def test_sending_weird_messages__unknown_even_type(self): - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + p1, p2 = graph.peers.values() async def send_weird_messages(): await util.wait_for2(p1.initialized, 1) @@ -1831,8 +1882,8 @@ class TestPeerDirect(TestPeer): self.assertTrue(isinstance(failing_task.exception().__cause__, lnmsg.UnknownMandatoryMsgType)) async def test_sending_weird_messages__known_msg_with_insufficient_length(self): - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + p1, p2 = graph.peers.values() async def send_weird_messages(): await util.wait_for2(p1.initialized, 1) @@ -1871,8 +1922,9 @@ class TestPeerDirect(TestPeer): which behave differently and use the persisted `LNWallet.dont_expire_htlcs` dict. """ async def run_test(test_trampoline): - alice_channel, bob_channel = create_test_channels() - alice_p, bob_p, alice_w, bob_w = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + alice_p, bob_p = graph.peers.values() + alice_w, bob_w = graph.workers.values() lnaddr, pay_req = self.prepare_invoice(bob_w, min_final_cltv_delta=150) del bob_w._preimages[pay_req.rhash] # del preimage so bob doesn't settle @@ -2011,8 +2063,9 @@ class TestPeerDirect(TestPeer): preimage is available. """ async def run_test(test_trampoline, test_expiry): - alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['single_chan']) + p1, p2 = graph.peers.values() + w1, w2 = graph.workers.values() if test_trampoline: await self._activate_trampoline(w1) # declare bob as trampoline node