diff --git a/tests/test_lntransport.py b/tests/test_lntransport.py index ac057fbd2..07bd1fc27 100644 --- a/tests/test_lntransport.py +++ b/tests/test_lntransport.py @@ -1,8 +1,10 @@ import asyncio +from typing import List import electrum_ecc as ecc from electrum import util +from electrum import lntransport from electrum.lntransport import LNPeerAddr, LNResponderTransport, LNTransport, extract_nodeid, split_host_port, ConnStringFormatError from electrum.util import OldTaskGroup @@ -71,6 +73,7 @@ class TestLNTransport(ElectrumTestCase): async def cb(reader, writer): t = LNResponderTransport(responder_key.get_secret_bytes(), reader, writer) + transports.append(t) self.assertEqual(await t.handshake(), initiator_key.get_public_key_bytes()) async with OldTaskGroup() as group: await group.spawn(read_messages(t, messages_sent_by_client)) @@ -79,12 +82,14 @@ class TestLNTransport(ElectrumTestCase): async def connect(port: int): peer_addr = LNPeerAddr('127.0.0.1', port, responder_key.get_public_key_bytes()) t = LNTransport(initiator_key.get_secret_bytes(), peer_addr, e_proxy=None) + transports.append(t) await t.handshake() async with OldTaskGroup() as group: await group.spawn(read_messages(t, messages_sent_by_server)) await group.spawn(write_messages(t, messages_sent_by_client)) server_shaked.set() + transports = [] # type: List[lntransport.LNTransportBase] async def f(): server = await asyncio.start_server(cb, '127.0.0.1', port=None) server_port = server.sockets[0].getsockname()[1] @@ -94,6 +99,8 @@ class TestLNTransport(ElectrumTestCase): await group.spawn(responder_shaked.wait()) await group.spawn(server_shaked.wait()) finally: + for t in transports: + t.close() server.close() await f()