LNGossip: sync channel db using query_channel_range
This commit is contained in:
@@ -57,9 +57,7 @@ class Peer(Logger):
|
||||
|
||||
def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase):
|
||||
self.initialized = asyncio.Event()
|
||||
self.node_anns = []
|
||||
self.chan_anns = []
|
||||
self.chan_upds = []
|
||||
self.querying_lock = asyncio.Lock()
|
||||
self.transport = transport
|
||||
self.pubkey = pubkey
|
||||
self.lnworker = lnworker
|
||||
@@ -70,6 +68,7 @@ class Peer(Logger):
|
||||
self.lnwatcher = lnworker.network.lnwatcher
|
||||
self.channel_db = lnworker.network.channel_db
|
||||
self.ping_time = 0
|
||||
self.reply_channel_range = asyncio.Queue()
|
||||
self.shutdown_received = defaultdict(asyncio.Future)
|
||||
self.channel_accepted = defaultdict(asyncio.Queue)
|
||||
self.channel_reestablished = defaultdict(asyncio.Future)
|
||||
@@ -89,7 +88,7 @@ class Peer(Logger):
|
||||
|
||||
def send_message(self, message_name: str, **kwargs):
|
||||
assert type(message_name) is str
|
||||
self.logger.info(f"Sending {message_name.upper()}")
|
||||
self.logger.debug(f"Sending {message_name.upper()}")
|
||||
self.transport.send_bytes(encode_msg(message_name, **kwargs))
|
||||
|
||||
async def initialize(self):
|
||||
@@ -177,13 +176,13 @@ class Peer(Logger):
|
||||
self.initialized.set()
|
||||
|
||||
def on_node_announcement(self, payload):
|
||||
self.node_anns.append(payload)
|
||||
self.channel_db.node_anns.append(payload)
|
||||
|
||||
def on_channel_update(self, payload):
|
||||
self.chan_upds.append(payload)
|
||||
self.channel_db.chan_upds.append(payload)
|
||||
|
||||
def on_channel_announcement(self, payload):
|
||||
self.chan_anns.append(payload)
|
||||
self.channel_db.chan_anns.append(payload)
|
||||
|
||||
def on_announcement_signatures(self, payload):
|
||||
channel_id = payload['channel_id']
|
||||
@@ -207,15 +206,11 @@ class Peer(Logger):
|
||||
@handle_disconnect
|
||||
async def main_loop(self):
|
||||
async with aiorpcx.TaskGroup() as group:
|
||||
await group.spawn(self._gossip_loop())
|
||||
await group.spawn(self._message_loop())
|
||||
# kill group if the peer times out
|
||||
await group.spawn(asyncio.wait_for(self.initialized.wait(), 10))
|
||||
|
||||
@log_exceptions
|
||||
async def _gossip_loop(self):
|
||||
await self.initialized.wait()
|
||||
timestamp = self.channel_db.get_last_timestamp()
|
||||
def request_gossip(self, timestamp=0):
|
||||
if timestamp == 0:
|
||||
self.logger.info('requesting whole channel graph')
|
||||
else:
|
||||
@@ -225,28 +220,47 @@ class Peer(Logger):
|
||||
chain_hash=constants.net.rev_genesis_bytes(),
|
||||
first_timestamp=timestamp,
|
||||
timestamp_range=b'\xff'*4)
|
||||
while True:
|
||||
await asyncio.sleep(5)
|
||||
if self.node_anns:
|
||||
self.channel_db.on_node_announcement(self.node_anns)
|
||||
self.node_anns = []
|
||||
if self.chan_anns:
|
||||
self.channel_db.on_channel_announcement(self.chan_anns)
|
||||
self.chan_anns = []
|
||||
if self.chan_upds:
|
||||
self.channel_db.on_channel_update(self.chan_upds)
|
||||
self.chan_upds = []
|
||||
# todo: enable when db is fixed
|
||||
#need_to_get = sorted(self.channel_db.missing_short_chan_ids())
|
||||
#if need_to_get and not self.receiving_channels:
|
||||
# self.logger.info(f'missing {len(need_to_get)} channels')
|
||||
# zlibencoded = zlib.compress(bfh(''.join(need_to_get[0:100])))
|
||||
# self.send_message(
|
||||
# 'query_short_channel_ids',
|
||||
# chain_hash=constants.net.rev_genesis_bytes(),
|
||||
# len=1+len(zlibencoded),
|
||||
# encoded_short_ids=b'\x01' + zlibencoded)
|
||||
# self.receiving_channels = True
|
||||
|
||||
def query_channel_range(self, index, num):
|
||||
self.logger.info(f'query channel range')
|
||||
self.send_message(
|
||||
'query_channel_range',
|
||||
chain_hash=constants.net.rev_genesis_bytes(),
|
||||
first_blocknum=index,
|
||||
number_of_blocks=num)
|
||||
|
||||
def encode_short_ids(self, ids):
|
||||
return chr(1) + zlib.compress(bfh(''.join(ids)))
|
||||
|
||||
def decode_short_ids(self, encoded):
|
||||
if encoded[0] == 0:
|
||||
decoded = encoded[1:]
|
||||
elif encoded[0] == 1:
|
||||
decoded = zlib.decompress(encoded[1:])
|
||||
else:
|
||||
raise BaseException('zlib')
|
||||
ids = [decoded[i:i+8] for i in range(0, len(decoded), 8)]
|
||||
return ids
|
||||
|
||||
def on_reply_channel_range(self, payload):
|
||||
first = int.from_bytes(payload['first_blocknum'], 'big')
|
||||
num = int.from_bytes(payload['number_of_blocks'], 'big')
|
||||
complete = bool(payload['complete'])
|
||||
encoded = payload['encoded_short_ids']
|
||||
ids = self.decode_short_ids(encoded)
|
||||
self.reply_channel_range.put_nowait((first, num, complete, ids))
|
||||
|
||||
async def query_short_channel_ids(self, ids, compressed=True):
|
||||
await self.querying_lock.acquire()
|
||||
#self.logger.info('querying {} short_channel_ids'.format(len(ids)))
|
||||
s = b''.join(ids)
|
||||
encoded = zlib.compress(s) if compressed else s
|
||||
prefix = b'\x01' if compressed else b'\x00'
|
||||
self.send_message(
|
||||
'query_short_channel_ids',
|
||||
chain_hash=constants.net.rev_genesis_bytes(),
|
||||
len=1+len(encoded),
|
||||
encoded_short_ids=prefix+encoded)
|
||||
|
||||
async def _message_loop(self):
|
||||
try:
|
||||
@@ -260,7 +274,7 @@ class Peer(Logger):
|
||||
self.ping_if_required()
|
||||
|
||||
def on_reply_short_channel_ids_end(self, payload):
|
||||
self.receiving_channels = False
|
||||
self.querying_lock.release()
|
||||
|
||||
def close_and_cleanup(self):
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user