1
0

verify channel updates in peer's TaskGroup

This commit is contained in:
ThomasV
2019-05-15 16:09:23 +02:00
parent 308dc6aa6b
commit 522ce5bb9f
4 changed files with 105 additions and 84 deletions

View File

@@ -35,7 +35,6 @@ from collections import defaultdict
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
import binascii
import base64
import asyncio
from sqlalchemy import Column, ForeignKey, Integer, String, Boolean
from sqlalchemy.orm.query import Query
@@ -224,7 +223,6 @@ class ChannelDB(SqlDB):
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
self.ca_verifier = LNChannelVerifier(network, self)
self.update_counts()
self.gossip_queue = asyncio.Queue()
@sql
def update_counts(self):
@@ -358,27 +356,46 @@ class ChannelDB(SqlDB):
return r.max_timestamp or 0
@sql
@profiler
def on_channel_update(self, msg_payloads, trusted=False):
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
def get_info_for_updates(self, msg_payloads):
short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads]
channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
new_policies = {}
for msg_payload in msg_payloads:
short_channel_id = msg_payload['short_channel_id']
if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']:
continue
return channel_infos
@profiler
def filter_channel_updates(self, payloads):
# add 'node_id' to payload
channel_infos = self.get_info_for_updates(payloads)
known = []
unknown = []
for payload in payloads:
short_channel_id = payload['short_channel_id']
channel_info = channel_infos.get(short_channel_id)
if not channel_info:
unknown.append(short_channel_id)
continue
flags = int.from_bytes(msg_payload['channel_flags'], 'big')
flags = int.from_bytes(payload['channel_flags'], 'big')
direction = flags & FLAG_DIRECTION
node_id = channel_info.node1_id if direction == 0 else channel_info.node2_id
if not trusted and not verify_sig_for_channel_update(msg_payload, bytes.fromhex(node_id)):
continue
short_channel_id = channel_info.short_channel_id
node_id = bfh(channel_info.node1_id if direction == 0 else channel_info.node2_id)
payload['node_id'] = node_id
known.append(payload)
return known, unknown
def add_channel_update(self, payload):
# called in tests/test_lnrouter
good, bad = self.filter_channel_updates([payload])
assert len(bad) == 0
self.on_channel_update(good)
@sql
@profiler
def on_channel_update(self, msg_payloads):
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
new_policies = {}
for msg_payload in msg_payloads:
short_channel_id = msg_payload['short_channel_id'].hex()
node_id = msg_payload['node_id'].hex()
new_policy = Policy.from_msg(msg_payload, node_id, short_channel_id)
#self.logger.info(f'on_channel_update {datetime.fromtimestamp(new_policy.timestamp).ctime()}')
old_policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=node_id).one_or_none()