start asyncio loop in test_lnrouter and test_lnpeer
This commit is contained in:
@@ -1,4 +1,3 @@
|
|||||||
import unittest
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import tempfile
|
import tempfile
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
@@ -11,7 +10,7 @@ from electrum.ecc import ECPrivkey
|
|||||||
from electrum import simple_config, lnutil
|
from electrum import simple_config, lnutil
|
||||||
from electrum.lnaddr import lnencode, LnAddr, lndecode
|
from electrum.lnaddr import lnencode, LnAddr, lndecode
|
||||||
from electrum.bitcoin import COIN, sha256
|
from electrum.bitcoin import COIN, sha256
|
||||||
from electrum.util import bh2u, set_verbosity
|
from electrum.util import bh2u, set_verbosity, create_and_start_event_loop
|
||||||
from electrum.lnpeer import Peer
|
from electrum.lnpeer import Peer
|
||||||
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
|
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
|
||||||
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
|
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
|
||||||
@@ -21,6 +20,7 @@ from electrum.lnworker import LNWorker
|
|||||||
from electrum.lnmsg import encode_msg, decode_msg
|
from electrum.lnmsg import encode_msg, decode_msg
|
||||||
|
|
||||||
from .test_lnchannel import create_test_channels
|
from .test_lnchannel import create_test_channels
|
||||||
|
from . import SequentialTestCase
|
||||||
|
|
||||||
def keypair():
|
def keypair():
|
||||||
priv = ECPrivkey.generate_random_key().get_secret_bytes()
|
priv = ECPrivkey.generate_random_key().get_secret_bytes()
|
||||||
@@ -37,12 +37,12 @@ class MockNetwork:
|
|||||||
def __init__(self, tx_queue):
|
def __init__(self, tx_queue):
|
||||||
self.callbacks = defaultdict(list)
|
self.callbacks = defaultdict(list)
|
||||||
self.lnwatcher = None
|
self.lnwatcher = None
|
||||||
|
self.interface = None
|
||||||
user_config = {}
|
user_config = {}
|
||||||
user_dir = tempfile.mkdtemp(prefix="electrum-lnpeer-test-")
|
user_dir = tempfile.mkdtemp(prefix="electrum-lnpeer-test-")
|
||||||
self.config = simple_config.SimpleConfig(user_config, read_user_dir_function=lambda: user_dir)
|
self.config = simple_config.SimpleConfig(user_config, read_user_dir_function=lambda: user_dir)
|
||||||
self.asyncio_loop = asyncio.get_event_loop()
|
self.asyncio_loop = asyncio.get_event_loop()
|
||||||
self.channel_db = ChannelDB(self)
|
self.channel_db = ChannelDB(self)
|
||||||
self.interface = None
|
|
||||||
self.path_finder = LNPathFinder(self.channel_db)
|
self.path_finder = LNPathFinder(self.channel_db)
|
||||||
self.tx_queue = tx_queue
|
self.tx_queue = tx_queue
|
||||||
|
|
||||||
@@ -159,14 +159,23 @@ def transport_pair(name1, name2):
|
|||||||
t2.other_mock_transport = t1
|
t2.other_mock_transport = t1
|
||||||
return t1, t2
|
return t1, t2
|
||||||
|
|
||||||
class TestPeer(unittest.TestCase):
|
class TestPeer(SequentialTestCase):
|
||||||
@staticmethod
|
|
||||||
def setUpClass():
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
super().setUpClass()
|
||||||
set_verbosity(True)
|
set_verbosity(True)
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.asyncio_loop, self._stop_loop, self._loop_thread = create_and_start_event_loop()
|
||||||
self.alice_channel, self.bob_channel = create_test_channels()
|
self.alice_channel, self.bob_channel = create_test_channels()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
super().tearDown()
|
||||||
|
self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1)
|
||||||
|
self._loop_thread.join(timeout=1)
|
||||||
|
|
||||||
def test_require_data_loss_protect(self):
|
def test_require_data_loss_protect(self):
|
||||||
mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel, tx_queue=None)
|
mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel, tx_queue=None)
|
||||||
mock_transport = NoFeaturesTransport('')
|
mock_transport = NoFeaturesTransport('')
|
||||||
@@ -232,8 +241,10 @@ class TestPeer(unittest.TestCase):
|
|||||||
self.assertEqual(await fut, 'Payment received')
|
self.assertEqual(await fut, 'Payment received')
|
||||||
gath.cancel()
|
gath.cancel()
|
||||||
gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop())
|
gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop())
|
||||||
|
async def f():
|
||||||
|
await gath
|
||||||
with self.assertRaises(asyncio.CancelledError):
|
with self.assertRaises(asyncio.CancelledError):
|
||||||
run(gath)
|
run(f())
|
||||||
|
|
||||||
def test_channel_usage_after_closing(self):
|
def test_channel_usage_after_closing(self):
|
||||||
p1, p2, w1, w2, q1, q2 = self.prepare_peers()
|
p1, p2, w1, w2, q1, q2 = self.prepare_peers()
|
||||||
@@ -253,8 +264,10 @@ class TestPeer(unittest.TestCase):
|
|||||||
peer = w1.peers[route[0].node_id]
|
peer = w1.peers[route[0].node_id]
|
||||||
# AssertionError is ok since we shouldn't use old routes, and the
|
# AssertionError is ok since we shouldn't use old routes, and the
|
||||||
# route finding should fail when channel is closed
|
# route finding should fail when channel is closed
|
||||||
|
async def f():
|
||||||
|
await asyncio.gather(w1._pay_to_route(route, addr, pay_req), p1._message_loop(), p2._message_loop())
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
run(asyncio.gather(w1._pay_to_route(route, addr, pay_req), p1._message_loop(), p2._message_loop()))
|
run(f())
|
||||||
|
|
||||||
def run(coro):
|
def run(coro):
|
||||||
return asyncio.get_event_loop().run_until_complete(coro)
|
return asyncio.run_coroutine_threadsafe(coro, loop=asyncio.get_event_loop()).result()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import tempfile
|
|||||||
import shutil
|
import shutil
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from electrum.util import bh2u, bfh
|
from electrum.util import bh2u, bfh, create_and_start_event_loop
|
||||||
from electrum.lnonion import (OnionHopsDataSingle, new_onion_packet, OnionPerHop,
|
from electrum.lnonion import (OnionHopsDataSingle, new_onion_packet, OnionPerHop,
|
||||||
process_onion_packet, _decode_onion_error, decode_onion_error,
|
process_onion_packet, _decode_onion_error, decode_onion_error,
|
||||||
OnionFailureCode)
|
OnionFailureCode)
|
||||||
@@ -34,11 +34,20 @@ class Test_LNRouter(TestCaseForTestnet):
|
|||||||
cls.electrum_path = tempfile.mkdtemp()
|
cls.electrum_path = tempfile.mkdtemp()
|
||||||
cls.config = SimpleConfig({'electrum_path': cls.electrum_path})
|
cls.config = SimpleConfig({'electrum_path': cls.electrum_path})
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.asyncio_loop, self._stop_loop, self._loop_thread = create_and_start_event_loop()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
super().tearDownClass()
|
super().tearDownClass()
|
||||||
shutil.rmtree(cls.electrum_path)
|
shutil.rmtree(cls.electrum_path)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
super().tearDown()
|
||||||
|
self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1)
|
||||||
|
self._loop_thread.join(timeout=1)
|
||||||
|
|
||||||
def test_find_path_for_payment(self):
|
def test_find_path_for_payment(self):
|
||||||
class fake_network:
|
class fake_network:
|
||||||
config = self.config
|
config = self.config
|
||||||
|
|||||||
Reference in New Issue
Block a user