verify channel updates in peer's TaskGroup
This commit is contained in:
@@ -35,7 +35,6 @@ from collections import defaultdict
|
||||
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
|
||||
import binascii
|
||||
import base64
|
||||
import asyncio
|
||||
|
||||
from sqlalchemy import Column, ForeignKey, Integer, String, Boolean
|
||||
from sqlalchemy.orm.query import Query
|
||||
@@ -224,7 +223,6 @@ class ChannelDB(SqlDB):
|
||||
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
|
||||
self.ca_verifier = LNChannelVerifier(network, self)
|
||||
self.update_counts()
|
||||
self.gossip_queue = asyncio.Queue()
|
||||
|
||||
@sql
|
||||
def update_counts(self):
|
||||
@@ -358,27 +356,46 @@ class ChannelDB(SqlDB):
|
||||
return r.max_timestamp or 0
|
||||
|
||||
@sql
|
||||
@profiler
|
||||
def on_channel_update(self, msg_payloads, trusted=False):
|
||||
if type(msg_payloads) is dict:
|
||||
msg_payloads = [msg_payloads]
|
||||
def get_info_for_updates(self, msg_payloads):
|
||||
short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads]
|
||||
channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
|
||||
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
|
||||
new_policies = {}
|
||||
for msg_payload in msg_payloads:
|
||||
short_channel_id = msg_payload['short_channel_id']
|
||||
if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']:
|
||||
continue
|
||||
return channel_infos
|
||||
|
||||
@profiler
|
||||
def filter_channel_updates(self, payloads):
|
||||
# add 'node_id' to payload
|
||||
channel_infos = self.get_info_for_updates(payloads)
|
||||
known = []
|
||||
unknown = []
|
||||
for payload in payloads:
|
||||
short_channel_id = payload['short_channel_id']
|
||||
channel_info = channel_infos.get(short_channel_id)
|
||||
if not channel_info:
|
||||
unknown.append(short_channel_id)
|
||||
continue
|
||||
flags = int.from_bytes(msg_payload['channel_flags'], 'big')
|
||||
flags = int.from_bytes(payload['channel_flags'], 'big')
|
||||
direction = flags & FLAG_DIRECTION
|
||||
node_id = channel_info.node1_id if direction == 0 else channel_info.node2_id
|
||||
if not trusted and not verify_sig_for_channel_update(msg_payload, bytes.fromhex(node_id)):
|
||||
continue
|
||||
short_channel_id = channel_info.short_channel_id
|
||||
node_id = bfh(channel_info.node1_id if direction == 0 else channel_info.node2_id)
|
||||
payload['node_id'] = node_id
|
||||
known.append(payload)
|
||||
return known, unknown
|
||||
|
||||
def add_channel_update(self, payload):
|
||||
# called in tests/test_lnrouter
|
||||
good, bad = self.filter_channel_updates([payload])
|
||||
assert len(bad) == 0
|
||||
self.on_channel_update(good)
|
||||
|
||||
@sql
|
||||
@profiler
|
||||
def on_channel_update(self, msg_payloads):
|
||||
if type(msg_payloads) is dict:
|
||||
msg_payloads = [msg_payloads]
|
||||
new_policies = {}
|
||||
for msg_payload in msg_payloads:
|
||||
short_channel_id = msg_payload['short_channel_id'].hex()
|
||||
node_id = msg_payload['node_id'].hex()
|
||||
new_policy = Policy.from_msg(msg_payload, node_id, short_channel_id)
|
||||
#self.logger.info(f'on_channel_update {datetime.fromtimestamp(new_policy.timestamp).ctime()}')
|
||||
old_policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=node_id).one_or_none()
|
||||
|
||||
Reference in New Issue
Block a user