sqlite in lnrouter: lnpeer: introduce _gossip_loop for gossip handling separated from message handling
This commit is contained in:
@@ -59,7 +59,6 @@ class Peer(PrintError):
|
||||
self.node_anns = []
|
||||
self.chan_anns = []
|
||||
self.chan_upds = []
|
||||
self.last_chan_db_upd = time.time()
|
||||
self.transport = transport
|
||||
self.pubkey = pubkey
|
||||
self.lnworker = lnworker
|
||||
@@ -209,15 +208,31 @@ class Peer(PrintError):
|
||||
@log_exceptions
|
||||
@handle_disconnect
|
||||
async def main_loop(self):
|
||||
"""
|
||||
This is used in LNWorker and is necessary so that we don't kill the main
|
||||
task group. It is not merged with _main_loop, so that we can test if the
|
||||
correct exceptions are getting thrown using _main_loop.
|
||||
"""
|
||||
await self._main_loop()
|
||||
async with aiorpcx.TaskGroup() as group:
|
||||
await group.spawn(self._gossip_loop())
|
||||
await group.spawn(self._message_loop())
|
||||
|
||||
async def _main_loop(self):
|
||||
"""This is separate from main_loop for the tests."""
|
||||
async def _gossip_loop(self):
|
||||
await self.initialized.wait()
|
||||
while True:
|
||||
await asyncio.sleep(5)
|
||||
if self.node_anns:
|
||||
self.channel_db.on_node_announcement(self.node_anns)
|
||||
self.node_anns = []
|
||||
if self.chan_anns:
|
||||
self.channel_db.on_channel_announcement(self.chan_anns)
|
||||
self.chan_anns = []
|
||||
if self.chan_upds:
|
||||
self.channel_db.on_channel_update(self.chan_upds)
|
||||
self.chan_upds = []
|
||||
need_to_get = self.channel_db.missing_short_chan_ids() #type: Set[int]
|
||||
if need_to_get and not self.receiving_channels:
|
||||
self.print_error('QUERYING SHORT CHANNEL IDS; missing', len(need_to_get), 'channels')
|
||||
zlibencoded = zlib.compress(bfh(''.join(need_to_get)))
|
||||
self.send_message('query_short_channel_ids', chain_hash=bytes.fromhex(bitcoin.rev_hex(constants.net.GENESIS)), len=1+len(zlibencoded), encoded_short_ids=b'\x01' + zlibencoded)
|
||||
self.receiving_channels = True
|
||||
|
||||
async def _message_loop(self):
|
||||
try:
|
||||
await asyncio.wait_for(self.initialize(), 10)
|
||||
except (OSError, asyncio.TimeoutError, HandshakeFailed) as e:
|
||||
@@ -227,21 +242,6 @@ class Peer(PrintError):
|
||||
async for msg in self.transport.read_messages():
|
||||
self.process_message(msg)
|
||||
await asyncio.sleep(.01)
|
||||
if time.time() - self.last_chan_db_upd > 5:
|
||||
self.last_chan_db_upd = time.time()
|
||||
self.channel_db.on_node_announcement(self.node_anns)
|
||||
self.node_anns = []
|
||||
self.channel_db.on_channel_announcement(self.chan_anns)
|
||||
self.chan_anns = []
|
||||
self.channel_db.on_channel_update(self.chan_upds)
|
||||
self.chan_upds = []
|
||||
need_to_get = self.channel_db.missing_short_chan_ids() #type: Set[int]
|
||||
if need_to_get and not self.receiving_channels:
|
||||
self.print_error('QUERYING SHORT CHANNEL IDS; ', len(need_to_get))
|
||||
zlibencoded = zlib.compress(b"".join(x.to_bytes(byteorder='big', length=8) for x in need_to_get))
|
||||
self.send_message('query_short_channel_ids', chain_hash=bytes.fromhex(bitcoin.rev_hex(constants.net.GENESIS)), len=1+len(zlibencoded), encoded_short_ids=b'\x01' + zlibencoded)
|
||||
self.receiving_channels = True
|
||||
|
||||
self.ping_if_required()
|
||||
|
||||
def on_reply_short_channel_ids_end(self, payload):
|
||||
|
||||
@@ -347,7 +347,19 @@ class ChannelDB:
|
||||
|
||||
def missing_short_chan_ids(self) -> Set[int]:
|
||||
expr = not_(Policy.short_channel_id.in_(DBSession.query(ChannelInfo.short_channel_id)))
|
||||
return set(DBSession.query(Policy.short_channel_id).filter(expr).all())
|
||||
chan_ids_from_policy = set(x[0] for x in DBSession.query(Policy.short_channel_id).filter(expr).all())
|
||||
if chan_ids_from_policy:
|
||||
return chan_ids_from_policy
|
||||
# fetch channels for node_ids missing in node_info. that will also give us node_announcement
|
||||
expr = not_(ChannelInfo.node1_id.in_(DBSession.query(NodeInfo.node_id)))
|
||||
chan_ids_from_id1 = set(x[0] for x in DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
|
||||
if chan_ids_from_id1:
|
||||
return chan_ids_from_id1
|
||||
expr = not_(ChannelInfo.node2_id.in_(DBSession.query(NodeInfo.node_id)))
|
||||
chan_ids_from_id2 = set(x[0] for x in DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
|
||||
if chan_ids_from_id2:
|
||||
return chan_ids_from_id2
|
||||
return set()
|
||||
|
||||
def add_verified_channel_info(self, short_id, capacity):
|
||||
# called from lnchannelverifier
|
||||
@@ -390,6 +402,8 @@ class ChannelDB:
|
||||
if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']:
|
||||
continue
|
||||
channel_info = channel_infos.get(short_channel_id)
|
||||
if not channel_info:
|
||||
continue
|
||||
channel_info.on_channel_update(msg_payload, trusted=trusted)
|
||||
DBSession.commit()
|
||||
|
||||
|
||||
@@ -173,7 +173,7 @@ class TestPeer(unittest.TestCase):
|
||||
p1 = Peer(mock_lnworker, b"\x00" * 33, mock_transport, request_initial_sync=False)
|
||||
mock_lnworker.peer = p1
|
||||
with self.assertRaises(LightningPeerConnectionClosed):
|
||||
run(asyncio.wait_for(p1._main_loop(), 1))
|
||||
run(asyncio.wait_for(p1._message_loop(), 1))
|
||||
|
||||
def prepare_peers(self):
|
||||
k1, k2 = keypair(), keypair()
|
||||
@@ -231,7 +231,7 @@ class TestPeer(unittest.TestCase):
|
||||
print("HTLC ADDED")
|
||||
self.assertEqual(await fut, 'Payment received')
|
||||
gath.cancel()
|
||||
gath = asyncio.gather(pay(), p1._main_loop(), p2._main_loop())
|
||||
gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop())
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
run(gath)
|
||||
|
||||
@@ -254,7 +254,7 @@ class TestPeer(unittest.TestCase):
|
||||
# AssertionError is ok since we shouldn't use old routes, and the
|
||||
# route finding should fail when channel is closed
|
||||
with self.assertRaises(AssertionError):
|
||||
run(asyncio.gather(w1._pay_to_route(route, addr, pay_req), p1._main_loop(), p2._main_loop()))
|
||||
run(asyncio.gather(w1._pay_to_route(route, addr, pay_req), p1._message_loop(), p2._message_loop()))
|
||||
|
||||
def run(coro):
|
||||
return asyncio.get_event_loop().run_until_complete(coro)
|
||||
|
||||
Reference in New Issue
Block a user