diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 7f8f36ce5..6f887db76 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -67,7 +67,7 @@ def keypair(): class MockNetwork: - def __init__(self, tx_queue, *, config: SimpleConfig): + def __init__(self, *, config: SimpleConfig): self.lnwatcher = None self.interface = None self.fee_estimates = FeeTimeEstimates() @@ -78,7 +78,7 @@ class MockNetwork: self.channel_db.data_loaded.set() self.path_finder = LNPathFinder(self.channel_db) self.lngossip = MockLNGossip() - self.tx_queue = tx_queue + self.tx_queue = asyncio.Queue() self._blockchain = MockBlockchain() def get_local_height(self): @@ -88,8 +88,7 @@ class MockNetwork: return self._blockchain async def broadcast_transaction(self, tx): - if self.tx_queue: - await self.tx_queue.put(tx) + await self.tx_queue.put(tx) async def try_broadcasting(self, tx, name): await self.broadcast_transaction(tx) @@ -160,7 +159,7 @@ class MockLNWallet(LNWallet): TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 0 MPP_SPLIT_PART_FRACTION = 1 # this disables the forced splitting - def __init__(self, *, tx_queue, name, has_anchors, ln_xprv: str = None): + def __init__(self, *, name, has_anchors, ln_xprv: str = None): self.name = name self._user_dir = tempfile.mkdtemp(prefix="electrum-lnpeer-test-") @@ -168,7 +167,7 @@ class MockLNWallet(LNWallet): self.config.ENABLE_ANCHOR_CHANNELS = has_anchors self.config.INITIAL_TRAMPOLINE_FEE_LEVEL = 0 - network = MockNetwork(tx_queue, config=self.config) + network = MockNetwork(config=self.config) wallet = restore_wallet_from_text__for_unittest( "9dk", path=None, passphrase=name, config=self.config)['wallet'] # type: Abstract_Wallet @@ -561,9 +560,8 @@ class TestPeerDirect(TestPeer): def prepare_peers( self, alice_channel: Channel, bob_channel: Channel, ): - q1, q2 = asyncio.Queue(), asyncio.Queue() - w1 = MockLNWallet(tx_queue=q1, name=bob_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS) - w2 = MockLNWallet(tx_queue=q2, name=alice_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS) + w1 = MockLNWallet(name=bob_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS) + w2 = MockLNWallet(name=alice_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS) k1 = w1.node_keypair k2 = w2.node_keypair alice_channel.node_id = k2.pubkey @@ -585,11 +583,11 @@ class TestPeerDirect(TestPeer): # this populates the channel graph: p1.mark_open(alice_channel) p2.mark_open(bob_channel) - return p1, p2, w1, w2, q1, q2 + return p1, p2, w1, w2 async def test_reestablish(self): alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) for chan in (alice_channel, bob_channel): chan.peer_state = PeerState.DISCONNECTED async def reestablish(): @@ -608,7 +606,7 @@ class TestPeerDirect(TestPeer): random_seed = os.urandom(32) alice_channel, bob_channel = create_test_channels(random_seed=random_seed) alice_channel_0, bob_channel_0 = create_test_channels(random_seed=random_seed) # these are identical - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) lnaddr, pay_req = self.prepare_invoice(w2) async def pay(): result, log = await w1.pay_invoice(pay_req) @@ -617,7 +615,7 @@ class TestPeerDirect(TestPeer): gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) with self.assertRaises(asyncio.CancelledError): await gath - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel_0, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel_0, bob_channel) for chan in (alice_channel_0, bob_channel): chan.peer_state = PeerState.DISCONNECTED @@ -684,7 +682,7 @@ class TestPeerDirect(TestPeer): chan_AB, chan_BA = create_test_channels() # note: we don't start peer.htlc_switch() so that the fake htlcs are left alone. async def f(): - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA) + p1, p2, w1, w2 = self.prepare_peers(chan_AB, chan_BA) async with OldTaskGroup() as group: await group.spawn(p1._message_loop()) await group.spawn(p2._message_loop()) @@ -698,7 +696,7 @@ class TestPeerDirect(TestPeer): await group.cancel_remaining() # simulating disconnection. recreate transports. self.logger.info("simulating disconnection. recreating transports.") - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA) + p1, p2, w1, w2 = self.prepare_peers(chan_AB, chan_BA) for chan in (chan_AB, chan_BA): chan.peer_state = PeerState.DISCONNECTED async with OldTaskGroup() as group: @@ -739,7 +737,7 @@ class TestPeerDirect(TestPeer): chan_AB, chan_BA = create_test_channels() # note: we don't start peer.htlc_switch() so that the fake htlcs are left alone. async def f(): - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA) + p1, p2, w1, w2 = self.prepare_peers(chan_AB, chan_BA) async with OldTaskGroup() as group: await group.spawn(p1._message_loop()) await group.spawn(p2._message_loop()) @@ -754,7 +752,7 @@ class TestPeerDirect(TestPeer): await group.cancel_remaining() # simulating disconnection. recreate transports. self.logger.info("simulating disconnection. recreating transports.") - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA) + p1, p2, w1, w2 = self.prepare_peers(chan_AB, chan_BA) for chan in (chan_AB, chan_BA): chan.peer_state = PeerState.DISCONNECTED async with OldTaskGroup() as group: @@ -785,7 +783,7 @@ class TestPeerDirect(TestPeer): ): """Alice pays Bob a single HTLC via direct channel.""" alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) results = {} async def pay(lnaddr, pay_req): self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) @@ -871,7 +869,7 @@ class TestPeerDirect(TestPeer): async def test_check_invoice_before_payment(self): alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) async def try_paying_some_invoices(): # feature bits: unknown even fbit invoice_features = w2.features.for_invoice() | (1 << 990) # add undefined even fbit @@ -908,7 +906,7 @@ class TestPeerDirect(TestPeer): """ async def run_test(test_trampoline): alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) async def try_pay_with_too_low_final_cltv_delta(lnaddr, w1=w1, w2=w2): self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED)) @@ -951,7 +949,7 @@ class TestPeerDirect(TestPeer): """Tests that new htlcs paying an invoice that has already been expired will get rejected.""" async def run_test(test_trampoline): alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) # create lightning invoice in the past, so it is expired with mock.patch('time.time', return_value=int(time.time()) - 10000): @@ -994,7 +992,7 @@ class TestPeerDirect(TestPeer): """Test that we reject a payment if it is mpp and we didn't signal support for mpp in the invoice""" async def run_test(test_trampoline): alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) w1.config.TEST_FORCE_MPP = True # force alice to send mpp if test_trampoline: @@ -1034,7 +1032,7 @@ class TestPeerDirect(TestPeer): """Tests that new htlcs paying an invoice that has already been paid will get rejected.""" async def run_test(test_trampoline): alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) lnaddr, _pay_req = self.prepare_invoice(w2) @@ -1078,7 +1076,7 @@ class TestPeerDirect(TestPeer): the respective HTLCs until those are irrevocably committed to. """ alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) async def pay(): await util.wait_for2(p1.initialized, 1) await util.wait_for2(p2.initialized, 1) @@ -1153,7 +1151,7 @@ class TestPeerDirect(TestPeer): #@unittest.skip("too expensive") async def test_payments_stresstest(self): alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) alice_init_balance_msat = alice_channel.balance(HTLCOwner.LOCAL) bob_init_balance_msat = bob_channel.balance(HTLCOwner.LOCAL) num_payments = 50 @@ -1185,7 +1183,7 @@ class TestPeerDirect(TestPeer): # - Alice sends htlc2: 0.9 BTC, H2, S1 (total_msat=1 BTC) # - Bob(victim) reveals preimage for H1 and fulfills htlc1 (fails other) alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) async def pay(): self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash, direction=RECEIVED)) self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr2.paymenthash, direction=RECEIVED)) @@ -1258,7 +1256,7 @@ class TestPeerDirect(TestPeer): # - Alice sends htlc2: 0.1 BTC (total_msat=1 BTC) # - Bob(victim) reveals preimage and fulfills htlc2 (fails other) alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) async def pay(): self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash, direction=RECEIVED)) @@ -1325,7 +1323,7 @@ class TestPeerDirect(TestPeer): Test that the other htlc won't get settled if the mpp isn't complete anymore after failing the other htlc. """ alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) async def pay(): await util.wait_for2(p1.initialized, 1) await util.wait_for2(p2.initialized, 1) @@ -1410,7 +1408,7 @@ class TestPeerDirect(TestPeer): """ async def run_test(test_trampoline: bool): alice_channel, bob_channel = create_test_channels() - alice_peer, bob_peer, alice_wallet, bob_wallet, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + alice_peer, bob_peer, alice_wallet, bob_wallet = self.prepare_peers(alice_channel, bob_channel) bob_wallet.features |= LnFeatures.BASIC_MPP_OPT lnaddr1, pay_req1 = self.prepare_invoice(bob_wallet, amount_msat=10_000) @@ -1524,7 +1522,7 @@ class TestPeerDirect(TestPeer): async def _test_shutdown(self, alice_fee, bob_fee, alice_fee_range=None, bob_fee_range=None): alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) w1.network.config.TEST_SHUTDOWN_FEE = alice_fee w2.network.config.TEST_SHUTDOWN_FEE = bob_fee if alice_fee_range is not None: @@ -1561,7 +1559,7 @@ class TestPeerDirect(TestPeer): async def test_warning(self): alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) async def action(): await util.wait_for2(p1.initialized, 1) @@ -1573,7 +1571,7 @@ class TestPeerDirect(TestPeer): async def test_error(self): alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) async def action(): await util.wait_for2(p1.initialized, 1) @@ -1600,7 +1598,7 @@ class TestPeerDirect(TestPeer): # setting the upfront shutdown script in the channel config bob_channel.config[HTLCOwner.LOCAL].upfront_shutdown_script = b'' - p1, p2, w1, w2, q1, q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) w1.network.config.FEE_POLICY = 'feerate:5000' w2.network.config.FEE_POLICY = 'feerate:1000' @@ -1624,15 +1622,15 @@ class TestPeerDirect(TestPeer): with self.assertRaises(GracefulDisconnect): await test() # check that neither party broadcast a closing tx (as it was not even signed) - self.assertEqual(0, q1.qsize()) - self.assertEqual(0, q2.qsize()) + self.assertEqual(0, w1.network.tx_queue.qsize()) + self.assertEqual(0, w2.network.tx_queue.qsize()) # -- new scenario: # bob sends the same upfront_shutdown_script has he announced alice_channel.config[HTLCOwner.REMOTE].upfront_shutdown_script = bob_uss bob_channel.config[HTLCOwner.LOCAL].upfront_shutdown_script = bob_uss - p1, p2, w1, w2, q1, q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) w1.network.config.FEE_POLICY = 'feerate:5000' w2.network.config.FEE_POLICY = 'feerate:1000' @@ -1656,14 +1654,14 @@ class TestPeerDirect(TestPeer): await test() # check if p1 has broadcast the closing tx, and if it pays to Bob's uss - self.assertEqual(1, q1.qsize()) - closing_tx = q1.get_nowait() # type: Transaction + self.assertEqual(1, w1.network.tx_queue.qsize()) + closing_tx = w1.network.tx_queue.get_nowait() # type: Transaction self.assertEqual(2, len(closing_tx.outputs())) self.assertEqual(1, len(closing_tx.get_output_idxs_from_address(bob_uss_addr))) async def test_channel_usage_after_closing(self): alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, q1, q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) lnaddr, pay_req = self.prepare_invoice(w2) lnaddr = w1._check_bolt11_invoice(pay_req.lightning_invoice) @@ -1673,7 +1671,7 @@ class TestPeerDirect(TestPeer): await w1.force_close_channel(alice_channel.channel_id) # check if a tx (commitment transaction) was broadcasted: - assert q1.qsize() == 1 + assert w1.network.tx_queue.qsize() == 1 with self.assertRaises(NoPathFound) as e: await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr) @@ -1704,7 +1702,7 @@ class TestPeerDirect(TestPeer): async def test_sending_weird_messages_that_should_be_ignored(self): alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) async def send_weird_messages(): await util.wait_for2(p1.initialized, 1) @@ -1735,7 +1733,7 @@ class TestPeerDirect(TestPeer): async def test_sending_weird_messages__unknown_even_type(self): alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) async def send_weird_messages(): await util.wait_for2(p1.initialized, 1) @@ -1764,7 +1762,7 @@ class TestPeerDirect(TestPeer): async def test_sending_weird_messages__known_msg_with_insufficient_length(self): alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) async def send_weird_messages(): await util.wait_for2(p1.initialized, 1) @@ -1804,7 +1802,7 @@ class TestPeerDirect(TestPeer): """ async def run_test(test_trampoline): alice_channel, bob_channel = create_test_channels() - alice_p, bob_p, alice_w, bob_w, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + alice_p, bob_p, alice_w, bob_w = self.prepare_peers(alice_channel, bob_channel) lnaddr, pay_req = self.prepare_invoice(bob_w, min_final_cltv_delta=150) del bob_w._preimages[pay_req.rhash] # del preimage so bob doesn't settle @@ -1876,7 +1874,7 @@ class TestPeerDirect(TestPeer): """ NUM_ITERATIONS = 25 alice_channel, bob_channel = create_test_channels(max_accepted_htlcs=20) - alice_p, bob_p, alice_w, bob_w, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + alice_p, bob_p, alice_w, bob_w = self.prepare_peers(alice_channel, bob_channel) await self._activate_trampoline(alice_w) electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { @@ -1944,7 +1942,7 @@ class TestPeerDirect(TestPeer): """ async def run_test(test_trampoline, test_expiry): alice_channel, bob_channel = create_test_channels() - p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) if test_trampoline: await self._activate_trampoline(w1) # declare bob as trampoline node @@ -2017,11 +2015,10 @@ class TestPeerForwarding(TestPeer): def prepare_chans_and_peers_in_graph(self, graph_definition) -> Graph: workers = {} # type: Dict[str, MockLNWallet] - txs_queues = {k: asyncio.Queue() for k in graph_definition} # create workers for a, definition in graph_definition.items(): - workers[a] = MockLNWallet(tx_queue=txs_queues[a], name=a, has_anchors=self.TEST_ANCHOR_CHANNELS) + workers[a] = MockLNWallet(name=a, has_anchors=self.TEST_ANCHOR_CHANNELS) self._lnworkers_created.extend(list(workers.values())) keys = {name: w.node_keypair for name, w in workers.items()} diff --git a/tests/test_onion_message.py b/tests/test_onion_message.py index ce23b98d6..f33e2368f 100644 --- a/tests/test_onion_message.py +++ b/tests/test_onion_message.py @@ -352,8 +352,7 @@ class TestOnionMessageManager(ElectrumTestCase): async def test_request_and_reply(self): n = MockNetwork() - q1, q2 = asyncio.Queue(), asyncio.Queue() - lnw = MockLNWallet(tx_queue=q1, name='test_request_and_reply', has_anchors=False) + lnw = MockLNWallet(name='test_request_and_reply', has_anchors=False) def slow(*args, **kwargs): time.sleep(2*TIME_STEP) @@ -399,8 +398,7 @@ class TestOnionMessageManager(ElectrumTestCase): async def test_forward(self): n = MockNetwork() - q1 = asyncio.Queue() - lnw = MockLNWallet(tx_queue=q1, name='alice', has_anchors=False) + lnw = MockLNWallet(name='alice', has_anchors=False) lnw.node_keypair = self.alice self.was_sent = False @@ -437,8 +435,7 @@ class TestOnionMessageManager(ElectrumTestCase): async def test_receive_unsolicited(self): n = MockNetwork() - q1 = asyncio.Queue() - lnw = MockLNWallet(tx_queue=q1, name='dave', has_anchors=False) + lnw = MockLNWallet(name='dave', has_anchors=False) lnw.node_keypair = self.dave t = OnionMessageManager(lnw)