1
0

LNGossip: sync channel db using query_channel_range

This commit is contained in:
ThomasV
2019-05-13 14:30:02 +02:00
parent 95376226e8
commit 1011245c5e
3 changed files with 168 additions and 59 deletions

View File

@@ -57,9 +57,7 @@ class Peer(Logger):
def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase):
self.initialized = asyncio.Event()
self.node_anns = []
self.chan_anns = []
self.chan_upds = []
self.querying_lock = asyncio.Lock()
self.transport = transport
self.pubkey = pubkey
self.lnworker = lnworker
@@ -70,6 +68,7 @@ class Peer(Logger):
self.lnwatcher = lnworker.network.lnwatcher
self.channel_db = lnworker.network.channel_db
self.ping_time = 0
self.reply_channel_range = asyncio.Queue()
self.shutdown_received = defaultdict(asyncio.Future)
self.channel_accepted = defaultdict(asyncio.Queue)
self.channel_reestablished = defaultdict(asyncio.Future)
@@ -89,7 +88,7 @@ class Peer(Logger):
def send_message(self, message_name: str, **kwargs):
assert type(message_name) is str
self.logger.info(f"Sending {message_name.upper()}")
self.logger.debug(f"Sending {message_name.upper()}")
self.transport.send_bytes(encode_msg(message_name, **kwargs))
async def initialize(self):
@@ -177,13 +176,13 @@ class Peer(Logger):
self.initialized.set()
def on_node_announcement(self, payload):
self.node_anns.append(payload)
self.channel_db.node_anns.append(payload)
def on_channel_update(self, payload):
self.chan_upds.append(payload)
self.channel_db.chan_upds.append(payload)
def on_channel_announcement(self, payload):
self.chan_anns.append(payload)
self.channel_db.chan_anns.append(payload)
def on_announcement_signatures(self, payload):
channel_id = payload['channel_id']
@@ -207,15 +206,11 @@ class Peer(Logger):
@handle_disconnect
async def main_loop(self):
async with aiorpcx.TaskGroup() as group:
await group.spawn(self._gossip_loop())
await group.spawn(self._message_loop())
# kill group if the peer times out
await group.spawn(asyncio.wait_for(self.initialized.wait(), 10))
@log_exceptions
async def _gossip_loop(self):
await self.initialized.wait()
timestamp = self.channel_db.get_last_timestamp()
def request_gossip(self, timestamp=0):
if timestamp == 0:
self.logger.info('requesting whole channel graph')
else:
@@ -225,28 +220,47 @@ class Peer(Logger):
chain_hash=constants.net.rev_genesis_bytes(),
first_timestamp=timestamp,
timestamp_range=b'\xff'*4)
while True:
await asyncio.sleep(5)
if self.node_anns:
self.channel_db.on_node_announcement(self.node_anns)
self.node_anns = []
if self.chan_anns:
self.channel_db.on_channel_announcement(self.chan_anns)
self.chan_anns = []
if self.chan_upds:
self.channel_db.on_channel_update(self.chan_upds)
self.chan_upds = []
# todo: enable when db is fixed
#need_to_get = sorted(self.channel_db.missing_short_chan_ids())
#if need_to_get and not self.receiving_channels:
# self.logger.info(f'missing {len(need_to_get)} channels')
# zlibencoded = zlib.compress(bfh(''.join(need_to_get[0:100])))
# self.send_message(
# 'query_short_channel_ids',
# chain_hash=constants.net.rev_genesis_bytes(),
# len=1+len(zlibencoded),
# encoded_short_ids=b'\x01' + zlibencoded)
# self.receiving_channels = True
def query_channel_range(self, index, num):
    """Ask the peer for the short channel ids of channels funded in the
    block range [index, index + num).

    The peer answers asynchronously with one or more reply_channel_range
    messages, which are consumed by on_reply_channel_range.
    """
    # bugfix: original used an f-string with no placeholders (ruff F541);
    # a plain string literal is equivalent.
    self.logger.info('query channel range')
    self.send_message(
        'query_channel_range',
        chain_hash=constants.net.rev_genesis_bytes(),
        first_blocknum=index,
        number_of_blocks=num)
def encode_short_ids(self, ids):
    """Encode a list of hex short_channel_id strings into the BOLT-7
    encoded_short_ids wire format.

    Returns bytes: a one-byte encoding prefix (0x01 = zlib) followed by
    the zlib-compressed concatenation of the ids.
    """
    # bugfix: chr(1) is a str but zlib.compress() returns bytes, so the
    # original concatenation raised TypeError; use the bytes prefix b'\x01'
    # (matching the b'\x01' prefix used in query_short_channel_ids).
    return b'\x01' + zlib.compress(bfh(''.join(ids)))
def decode_short_ids(self, encoded):
    """Decode a BOLT-7 encoded_short_ids payload into a list of raw
    8-byte short_channel_ids.

    The first byte of `encoded` selects the encoding:
    0x00 = uncompressed, 0x01 = zlib-compressed.

    Raises:
        ValueError: if the encoding prefix byte is unknown.
    """
    if encoded[0] == 0:
        decoded = encoded[1:]
    elif encoded[0] == 1:
        decoded = zlib.decompress(encoded[1:])
    else:
        # bugfix: narrowed from bare BaseException — ValueError no longer
        # escapes `except Exception` handlers, and is still caught by any
        # existing `except BaseException` caller.
        raise ValueError('unknown encoding of short_channel_ids: {}'.format(encoded[0]))
    # the decoded payload is a flat concatenation of fixed-size 8-byte ids
    return [decoded[i:i+8] for i in range(0, len(decoded), 8)]
def on_reply_channel_range(self, payload):
    """Handle a reply_channel_range message from the peer.

    Decodes the advertised short_channel_ids and hands the parsed reply
    to whoever is awaiting on the reply_channel_range queue.
    """
    first_block = int.from_bytes(payload['first_blocknum'], byteorder='big')
    block_count = int.from_bytes(payload['number_of_blocks'], byteorder='big')
    is_complete = bool(payload['complete'])
    short_ids = self.decode_short_ids(payload['encoded_short_ids'])
    self.reply_channel_range.put_nowait(
        (first_block, block_count, is_complete, short_ids))
async def query_short_channel_ids(self, ids, compressed=True):
    """Request announcements/updates for the given channels.

    `ids` is a list of raw 8-byte short_channel_ids.  The querying lock
    is acquired here and released only when the peer answers with
    reply_short_channel_ids_end (see on_reply_short_channel_ids_end),
    so at most one such query is in flight per peer.
    """
    await self.querying_lock.acquire()
    payload = b''.join(ids)
    if compressed:
        prefix, encoded = b'\x01', zlib.compress(payload)
    else:
        prefix, encoded = b'\x00', payload
    self.send_message(
        'query_short_channel_ids',
        chain_hash=constants.net.rev_genesis_bytes(),
        len=1 + len(encoded),
        encoded_short_ids=prefix + encoded)
async def _message_loop(self):
try:
@@ -260,7 +274,7 @@ class Peer(Logger):
self.ping_if_required()
def on_reply_short_channel_ids_end(self, payload):
self.receiving_channels = False
self.querying_lock.release()
def close_and_cleanup(self):
try: