One use case is perhaps to save space if using trampoline anyway... More importantly, if using gossip, LNGossip heavily filters which messages we request and receive, and can e.g. miss new NodeAnnouncements, etc.; this is a quick-and-dirty workaround to force a fresh start.
918 lines
38 KiB
Python
918 lines
38 KiB
Python
# -*- coding: utf-8 -*-
|
|
#
|
|
# Electrum - lightweight Bitcoin client
|
|
# Copyright (C) 2018 The Electrum developers
|
|
#
|
|
# Permission is hereby granted, free of charge, to any person
|
|
# obtaining a copy of this software and associated documentation files
|
|
# (the "Software"), to deal in the Software without restriction,
|
|
# including without limitation the rights to use, copy, modify, merge,
|
|
# publish, distribute, sublicense, and/or sell copies of the Software,
|
|
# and to permit persons to whom the Software is furnished to do so,
|
|
# subject to the following conditions:
|
|
#
|
|
# The above copyright notice and this permission notice shall be
|
|
# included in all copies or substantial portions of the Software.
|
|
#
|
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
|
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
|
|
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
|
|
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
|
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
# SOFTWARE.
|
|
|
|
import time
|
|
import random
|
|
import os
|
|
from collections import defaultdict
|
|
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
|
|
import binascii
|
|
import base64
|
|
import asyncio
|
|
import threading
|
|
from enum import IntEnum
|
|
|
|
from aiorpcx import NetAddress
|
|
|
|
from .sql_db import SqlDB, sql
|
|
from . import constants, util
|
|
from .util import bh2u, profiler, get_headers_dir, is_ip_address, json_normalize
|
|
from .logging import Logger
|
|
from .lnutil import (LNPeerAddr, format_short_channel_id, ShortChannelID,
|
|
validate_features, IncompatibleOrInsaneFeatures, InvalidGossipMsg)
|
|
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
|
|
from .lnmsg import decode_msg
|
|
from . import ecc
|
|
from .crypto import sha256d
|
|
|
|
if TYPE_CHECKING:
|
|
from .network import Network
|
|
from .lnchannel import Channel
|
|
from .lnrouter import RouteEdge
|
|
from .simple_config import SimpleConfig
|
|
|
|
|
|
# Bit of a channel_update's channel_flags that marks the policy as disabled
# (tested by Policy.is_disabled).
FLAG_DISABLE = 1 << 1
# Bit of a channel_update's channel_flags that selects the direction:
# 0 -> the update originates from node1 of the channel, 1 -> from node2
# (see ChannelDB.add_channel_update).
FLAG_DIRECTION = 1 << 0
|
|
|
|
|
|
class ChannelInfo(NamedTuple):
    """Identity of a channel: its short id, its two endpoints and its capacity.

    node1_id / node2_id are stored in sorted order; capacity_sat is None
    whenever the capacity is not known.
    """
    short_channel_id: ShortChannelID
    node1_id: bytes
    node2_id: bytes
    capacity_sat: Optional[int]

    @staticmethod
    def from_msg(payload: dict) -> 'ChannelInfo':
        """Build a ChannelInfo from a decoded channel announcement payload.

        The feature bits are passed to validate_features() first, so whatever
        that raises for unacceptable features propagates to the caller.
        The capacity is not part of the payload and is left as None.
        """
        validate_features(int.from_bytes(payload['features'], 'big'))
        raw_scid = payload['short_channel_id']
        first_node = payload['node_id_1']
        second_node = payload['node_id_2']
        # the two node ids must already be in sorted order
        assert sorted([first_node, second_node]) == [first_node, second_node]
        return ChannelInfo(
            short_channel_id=ShortChannelID.normalize(raw_scid),
            node1_id=first_node,
            node2_id=second_node,
            capacity_sat=None,
        )

    @staticmethod
    def from_raw_msg(raw: bytes) -> 'ChannelInfo':
        """Decode a raw message and build a ChannelInfo from its payload."""
        decoded = decode_msg(raw)
        return ChannelInfo.from_msg(decoded[1])

    @staticmethod
    def from_route_edge(route_edge: 'RouteEdge') -> 'ChannelInfo':
        """Build a ChannelInfo (unknown capacity) from the endpoints of a route edge."""
        lower_id, higher_id = sorted((route_edge.start_node, route_edge.end_node))
        return ChannelInfo(
            short_channel_id=route_edge.short_channel_id,
            node1_id=lower_id,
            node2_id=higher_id,
            capacity_sat=None,
        )
|
|
|
|
|
|
class Policy(NamedTuple):
    """Forwarding parameters for one direction of a channel.

    `key` is the 8-byte short channel id followed by the id of the node the
    policy starts from (see the short_channel_id / start_node properties).
    """
    key: bytes
    cltv_expiry_delta: int
    htlc_minimum_msat: int
    htlc_maximum_msat: Optional[int]
    fee_base_msat: int
    fee_proportional_millionths: int
    channel_flags: int
    message_flags: int
    timestamp: int

    @staticmethod
    def from_msg(payload: dict) -> 'Policy':
        """Build a Policy from a decoded channel update payload.

        The payload must additionally carry a 'start_node' entry;
        'htlc_maximum_msat' is optional and defaults to None.
        """
        policy_key = payload['short_channel_id'] + payload['start_node']
        cltv_delta = payload['cltv_expiry_delta']
        min_msat = payload['htlc_minimum_msat']
        max_msat = payload.get('htlc_maximum_msat', None)
        base_fee = payload['fee_base_msat']
        proportional_fee = payload['fee_proportional_millionths']
        # both flag fields arrive as big-endian bytes
        msg_flags = int.from_bytes(payload['message_flags'], "big")
        chan_flags = int.from_bytes(payload['channel_flags'], "big")
        return Policy(
            key=policy_key,
            cltv_expiry_delta=cltv_delta,
            htlc_minimum_msat=min_msat,
            htlc_maximum_msat=max_msat,
            fee_base_msat=base_fee,
            fee_proportional_millionths=proportional_fee,
            channel_flags=chan_flags,
            message_flags=msg_flags,
            timestamp=payload['timestamp'],
        )

    @staticmethod
    def from_raw_msg(key:bytes, raw: bytes) -> 'Policy':
        """Decode a raw message; the start node is taken from key[8:]."""
        decoded_payload = decode_msg(raw)[1]
        decoded_payload['start_node'] = key[8:]
        return Policy.from_msg(decoded_payload)

    @staticmethod
    def from_route_edge(route_edge: 'RouteEdge') -> 'Policy':
        """Build a Policy from a route edge.

        Only the cltv delta and the two fee fields come from the edge; the
        htlc minimum, flags and timestamp are zero and the htlc maximum is None.
        """
        return Policy(
            key=route_edge.short_channel_id + route_edge.start_node,
            cltv_expiry_delta=route_edge.cltv_expiry_delta,
            htlc_minimum_msat=0,
            htlc_maximum_msat=None,
            fee_base_msat=route_edge.fee_base_msat,
            fee_proportional_millionths=route_edge.fee_proportional_millionths,
            channel_flags=0,
            message_flags=0,
            timestamp=0,
        )

    def is_disabled(self):
        """Return the FLAG_DISABLE bit of channel_flags (non-zero means disabled)."""
        disabled_bit = self.channel_flags & FLAG_DISABLE
        return disabled_bit

    @property
    def short_channel_id(self) -> ShortChannelID:
        """The first 8 bytes of `key`, normalized to a ShortChannelID."""
        scid_bytes = self.key[0:8]
        return ShortChannelID.normalize(scid_bytes)

    @property
    def start_node(self) -> bytes:
        """Everything in `key` after the 8-byte short channel id."""
        return self.key[8:]
|
|
|
|
|
|
class NodeInfo(NamedTuple):
    """What we keep from a node announcement: id, feature bits, timestamp and alias."""
    node_id: bytes
    features: int
    timestamp: int
    alias: str

    @staticmethod
    def from_msg(payload) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]:
        """Build a NodeInfo and the list of usable peer addresses from a decoded payload.

        The feature bits are passed to validate_features() first, so whatever
        that raises for unacceptable features propagates to the caller.
        Addresses that LNPeerAddr rejects with ValueError are skipped.
        An alias that is not valid UTF-8 is replaced by the empty string.
        """
        node_id = payload['node_id']
        features = int.from_bytes(payload['features'], "big")
        validate_features(features)
        addresses = NodeInfo.parse_addresses_field(payload['addresses'])
        peer_addrs = []
        for host, port in addresses:
            try:
                peer_addrs.append(LNPeerAddr(host=host, port=port, pubkey=node_id))
            except ValueError:
                # deliberately best-effort: drop addresses we cannot represent
                pass
        # the alias is padded with NUL bytes
        alias = payload['alias'].rstrip(b'\x00')
        try:
            alias = alias.decode('utf8')
        except UnicodeDecodeError:
            # was a bare `except:`, which would also swallow KeyboardInterrupt etc.
            alias = ''
        timestamp = payload['timestamp']
        node_info = NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias)
        return node_info, peer_addrs

    @staticmethod
    def from_raw_msg(raw: bytes) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]:
        """Decode a raw message and delegate to from_msg()."""
        payload_dict = decode_msg(raw)[1]
        return NodeInfo.from_msg(payload_dict)

    @staticmethod
    def parse_addresses_field(addresses_field):
        """Parse the concatenated address descriptors into a list of (host, port).

        Each descriptor starts with a one-byte type:
          0 -> no payload, skipped
          1 -> IPv4: 4 address bytes + 2 port bytes
          2 -> IPv6: 16 address bytes + 2 port bytes
          3 -> onion v2: 10 bytes (base32-encoded into the host) + 2 port bytes
          4 -> onion v3: 35 bytes (base32-encoded into the host) + 2 port bytes
        IP entries are kept only if is_ip_address() accepts them and the port
        is non-zero. Parsing stops at the first unknown type.
        """
        buf = addresses_field

        def read(n):
            # consume and return the next n bytes (fewer if the buffer runs out)
            nonlocal buf
            data, buf = buf[0:n], buf[n:]
            return data

        addresses = []
        while buf:
            atype = ord(read(1))
            if atype == 0:
                pass
            elif atype == 1:  # IPv4
                ipv4_addr = '.'.join(str(octet) for octet in read(4))
                port = int.from_bytes(read(2), 'big')
                if is_ip_address(ipv4_addr) and port != 0:
                    addresses.append((ipv4_addr, port))
            elif atype == 2:  # IPv6
                ipv6_addr = b':'.join([binascii.hexlify(read(2)) for _ in range(8)])
                ipv6_addr = ipv6_addr.decode('ascii')
                port = int.from_bytes(read(2), 'big')
                if is_ip_address(ipv6_addr) and port != 0:
                    addresses.append((ipv6_addr, port))
            elif atype == 3:  # onion v2
                host = base64.b32encode(read(10)) + b'.onion'
                host = host.decode('ascii').lower()
                port = int.from_bytes(read(2), 'big')
                addresses.append((host, port))
            elif atype == 4:  # onion v3
                host = base64.b32encode(read(35)) + b'.onion'
                host = host.decode('ascii').lower()
                port = int.from_bytes(read(2), 'big')
                addresses.append((host, port))
            else:
                # unknown address type
                # we don't know how long it is -> have to escape
                # if there are other addresses we could have parsed later, they are lost.
                break
        return addresses
|
|
|
|
|
|
class UpdateStatus(IntEnum):
    """Outcome of ChannelDB.add_channel_update for a single channel update."""
    # no known channel for the update's short channel id
    ORPHANED = 0
    # update is older than the max_age given by the caller
    EXPIRED = 1
    # update is not newer than what is stored, or is dated too far in the future
    DEPRECATED = 2
    # update was stored, but its policy fields equal the previous ones
    UNCHANGED = 3
    # update was stored and is new or different
    GOOD = 4
|
|
|
|
class CategorizedChannelUpdates(NamedTuple):
    """Channel update payloads grouped by the UpdateStatus they were given."""
    # no channel announcement for channel update
    orphaned: List
    # update older than two weeks
    expired: List
    # update older than database entry
    deprecated: List
    # unchanged policies
    unchanged: List
    # good updates
    good: List
|
|
|
|
|
|
def get_mychannel_info(short_channel_id: ShortChannelID,
                       my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[ChannelInfo]:
    """Return a ChannelInfo for one of our own channels, or None if it is not ours.

    The info is parsed from the channel's own (unsigned) channel announcement
    and completed with the capacity from the channel's constraints.
    """
    own_chan = my_channels.get(short_channel_id)
    if not own_chan:
        return None
    unsigned_announcement = own_chan.construct_channel_announcement_without_sigs()
    info = ChannelInfo.from_raw_msg(unsigned_announcement)
    return info._replace(capacity_sat=own_chan.constraints.capacity)
|
|
|
|
def get_mychannel_policy(short_channel_id: bytes, node_id: bytes,
                         my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[Policy]:
    """Return the Policy of one of our own channels in the direction starting at node_id.

    Returns None if the channel is not ours, if node_id is neither endpoint,
    or if the incoming direction is requested and no remote update is available.
    """
    own_chan = my_channels.get(short_channel_id)  # type: Optional[Channel]
    if not own_chan:
        return None

    if node_id == own_chan.node_id:  # incoming direction (to us)
        raw_remote_update = own_chan.get_remote_update()
        if not raw_remote_update:
            return None
        current_time = int(time.time())
        decoded_remote = decode_msg(raw_remote_update)[1]
        # the remote update is re-stamped with the current time
        decoded_remote['timestamp'] = current_time
        decoded_remote['start_node'] = node_id
        return Policy.from_msg(decoded_remote)

    if node_id == own_chan.get_local_pubkey():  # outgoing direction (from us)
        decoded_local = decode_msg(own_chan.get_outgoing_gossip_channel_update())[1]
        decoded_local['start_node'] = node_id
        return Policy.from_msg(decoded_local)

    return None
|
|
|
|
|
|
# Schema of the gossip database. Each table is created on demand by
# ChannelDB.create_database().

# channel_info: one row per channel; 'msg' is the raw 'channel_announcement'
# message (see ChannelDB._db_save_channel).
create_channel_info = """
CREATE TABLE IF NOT EXISTS channel_info (
short_channel_id BLOB(8),
msg BLOB,
PRIMARY KEY(short_channel_id)
)"""

# policy: one row per channel direction; 'key' is short_channel_id + node_id
# (see ChannelDB._db_delete_policy) and 'msg' is the raw 'channel_update' message.
create_policy = """
CREATE TABLE IF NOT EXISTS policy (
key BLOB(41),
msg BLOB,
PRIMARY KEY(key)
)"""

# address: known (host, port) pairs per node. 'timestamp' is written as 0 by
# ChannelDB._db_save_node_addresses and as the current time by
# ChannelDB.add_recent_peer.
create_address = """
CREATE TABLE IF NOT EXISTS address (
node_id BLOB(33),
host STRING(256),
port INTEGER NOT NULL,
timestamp INTEGER,
PRIMARY KEY(node_id, host, port)
)"""

# node_info: one row per node; 'msg' is the raw 'node_announcement' message
# (see ChannelDB._db_save_node_info).
create_node_info = """
CREATE TABLE IF NOT EXISTS node_info (
node_id BLOB(33),
msg BLOB,
PRIMARY KEY(node_id)
)"""
|
|
|
|
|
|
class ChannelDB(SqlDB):
    """In-memory view of the gossip graph, mirrored into the SQL tables.

    Channels, per-direction policies, nodes and node addresses are kept in
    dictionaries on the instance and written to the channel_info / policy /
    node_info / address tables through the _db_* methods.

    note: modifying or iterating the in-memory containers needs self.lock
    """

    NUM_MAX_RECENT_PEERS = 20

    def __init__(self, network: 'Network'):
        """Set up empty in-memory state; the stored data is read by load_data()."""
        path = self.get_file_path(network.config)
        super().__init__(network.asyncio_loop, path, commit_interval=100)
        self.lock = threading.RLock()
        self.num_nodes = 0
        self.num_channels = 0
        # previously only assigned in update_counts(); defined here so that the
        # attribute exists from construction on
        self.num_policies = 0
        self._channel_updates_for_private_channels = {}  # type: Dict[Tuple[bytes, bytes], dict]
        self.ca_verifier = LNChannelVerifier(network, self)

        # initialized in load_data
        # note: modify/iterate needs self.lock
        self._channels = {}  # type: Dict[ShortChannelID, ChannelInfo]
        self._policies = {}  # type: Dict[Tuple[bytes, ShortChannelID], Policy]  # (node_id, scid) -> Policy
        self._nodes = {}  # type: Dict[bytes, NodeInfo]  # node_id -> NodeInfo
        # node_id -> NetAddress -> timestamp
        self._addresses = defaultdict(dict)  # type: Dict[bytes, Dict[NetAddress, int]]
        self._channels_for_node = defaultdict(set)  # type: Dict[bytes, Set[ShortChannelID]]
        self._recent_peers = []  # type: List[bytes]  # list of node_ids
        self._chans_with_0_policies = set()  # type: Set[ShortChannelID]
        self._chans_with_1_policies = set()  # type: Set[ShortChannelID]
        self._chans_with_2_policies = set()  # type: Set[ShortChannelID]

        self.data_loaded = asyncio.Event()
        self.network = network  # only for callback

    @classmethod
    def get_file_path(cls, config: 'SimpleConfig') -> str:
        """Return the path of the 'gossip_db' file inside the headers directory of `config`."""
        return os.path.join(get_headers_dir(config), 'gossip_db')

    def update_counts(self):
        """Recompute num_nodes / num_channels / num_policies and fire the related callbacks."""
        self.num_nodes = len(self._nodes)
        self.num_channels = len(self._channels)
        self.num_policies = len(self._policies)
        util.trigger_callback('channel_db', self.num_nodes, self.num_channels, self.num_policies)
        util.trigger_callback('ln_gossip_sync_progress')

    def get_channel_ids(self):
        """Return a new set with the short channel ids of all known channels."""
        with self.lock:
            return set(self._channels.keys())

    def add_recent_peer(self, peer: LNPeerAddr):
        """Move `peer` to the front of the recent-peers list and store its address with the current time."""
        now = int(time.time())
        node_id = peer.pubkey
        with self.lock:
            self._addresses[node_id][peer.net_addr()] = now
            # list is ordered
            if node_id in self._recent_peers:
                self._recent_peers.remove(node_id)
            self._recent_peers.insert(0, node_id)
            self._recent_peers = self._recent_peers[:self.NUM_MAX_RECENT_PEERS]
        self._db_save_node_address(peer, now)

    def get_200_randomly_sorted_nodes_not_in(self, node_ids):
        """Return up to 200 known node ids in random order, excluding those in `node_ids`.

        `node_ids` may be any iterable of node ids (it used to have to be a set).
        """
        with self.lock:
            unshuffled = set(self._nodes.keys()).difference(node_ids)
        return random.sample(list(unshuffled), min(200, len(unshuffled)))

    def get_last_good_address(self, node_id: bytes) -> Optional[LNPeerAddr]:
        """Returns the address with the highest stored timestamp for given node.

        Returns None if no address is known or LNPeerAddr rejects it with ValueError.
        """
        # iterate under the lock: other threads may add addresses meanwhile
        with self.lock:
            addr_to_ts = self._addresses.get(node_id)
            if not addr_to_ts:
                return None
            # max() yields the first entry among equal timestamps, exactly like
            # the stable reverse sort it replaces
            addr = max(addr_to_ts, key=addr_to_ts.get)
        try:
            return LNPeerAddr(str(addr.host), addr.port, node_id)
        except ValueError:
            return None

    def get_recent_peers(self):
        """Return get_last_good_address() for every recent peer, most recent first.

        Entries can be None when no usable address exists for a peer.
        Raises Exception if load_data() has not finished yet.
        """
        if not self.data_loaded.is_set():
            raise Exception("channelDB data not loaded yet!")
        with self.lock:
            ret = [self.get_last_good_address(node_id)
                   for node_id in self._recent_peers]
            return ret

    # note: currently channel announcements are trusted by default (trusted=True);
    #       they are not SPV-verified. Verifying them would make the gossip sync
    #       even slower; especially as servers will start throttling us.
    #       It would probably put significant strain on servers if all clients
    #       verified the complete gossip.
    def add_channel_announcements(self, msg_payloads, *, trusted=True):
        """Add one payload dict, or an iterable of payload dicts, of channel announcements.

        Payloads for already-known channels, for another chain, or with
        unacceptable feature bits are skipped. With trusted=True the channel is
        stored immediately; otherwise it is handed to self.ca_verifier.
        """
        # note: signatures have already been verified.
        if isinstance(msg_payloads, dict):
            msg_payloads = [msg_payloads]
        added = 0
        for msg in msg_payloads:
            short_channel_id = ShortChannelID(msg['short_channel_id'])
            if short_channel_id in self._channels:
                continue
            if constants.net.rev_genesis_bytes() != msg['chain_hash']:
                self.logger.info("ChanAnn has unexpected chain_hash {}".format(bh2u(msg['chain_hash'])))
                continue
            try:
                # parsed here only for validation; the result itself is not needed
                ChannelInfo.from_msg(msg)
            except IncompatibleOrInsaneFeatures as e:
                self.logger.info(f"unknown or insane feature bits: {e!r}")
                continue
            if trusted:
                added += 1
                self.add_verified_channel_info(msg)
            else:
                added += self.ca_verifier.add_new_channel_info(short_channel_id, msg)

        self.update_counts()
        self.logger.debug('add_channel_announcement: %d/%d'%(added, len(msg_payloads)))

    def add_verified_channel_info(self, msg: dict, *, capacity_sat: Optional[int] = None) -> None:
        """Store the channel described by `msg`, and persist it if `msg` carries its 'raw' bytes.

        Silently does nothing if the feature bits are rejected.
        """
        try:
            channel_info = ChannelInfo.from_msg(msg)
        except IncompatibleOrInsaneFeatures:
            return
        channel_info = channel_info._replace(capacity_sat=capacity_sat)
        with self.lock:
            self._channels[channel_info.short_channel_id] = channel_info
            self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
            self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
        self._update_num_policies_for_chan(channel_info.short_channel_id)
        if 'raw' in msg:
            self._db_save_channel(channel_info.short_channel_id, msg['raw'])

    def policy_changed(self, old_policy: Policy, new_policy: Policy, verbose: bool) -> bool:
        """Return True if any field other than key/timestamp differs between the two policies.

        With verbose=True every differing field (or the fact that nothing
        changed) is logged.
        """
        compared_fields = (
            'cltv_expiry_delta',
            'htlc_minimum_msat',
            'htlc_maximum_msat',
            'fee_base_msat',
            'fee_proportional_millionths',
            'channel_flags',
            'message_flags',
        )
        changed = False
        for field in compared_fields:
            old_value = getattr(old_policy, field)
            new_value = getattr(new_policy, field)
            if old_value != new_value:
                changed = True
                if verbose:
                    self.logger.info(f'{field}: {old_value} -> {new_value}')
        if not changed and verbose:
            self.logger.info(f'policy unchanged: {old_policy.timestamp} -> {new_policy.timestamp}')
        return changed

    def add_channel_update(
            self, payload, *, max_age=None, verify=True, verbose=True) -> UpdateStatus:
        """Try to store one channel update and report what happened to it.

        Side effect: for updates of known channels, payload['start_node'] is set.
        With verify=True, verify_channel_update() may raise InvalidGossipMsg.
        """
        now = int(time.time())
        short_channel_id = ShortChannelID(payload['short_channel_id'])
        timestamp = payload['timestamp']
        if max_age and now - timestamp > max_age:
            return UpdateStatus.EXPIRED
        # reject updates dated more than a minute into the future
        if timestamp - now > 60:
            return UpdateStatus.DEPRECATED
        channel_info = self._channels.get(short_channel_id)
        if not channel_info:
            return UpdateStatus.ORPHANED
        flags = int.from_bytes(payload['channel_flags'], 'big')
        direction = flags & FLAG_DIRECTION
        start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
        payload['start_node'] = start_node
        # compare updates to existing database entries
        key = (start_node, short_channel_id)
        old_policy = self._policies.get(key)
        # an update must be more than 60s newer than the stored one to replace it
        if old_policy and timestamp <= old_policy.timestamp + 60:
            return UpdateStatus.DEPRECATED
        if verify:
            self.verify_channel_update(payload)
        policy = Policy.from_msg(payload)
        with self.lock:
            self._policies[key] = policy
        self._update_num_policies_for_chan(short_channel_id)
        if 'raw' in payload:
            self._db_save_policy(policy.key, payload['raw'])
        if old_policy and not self.policy_changed(old_policy, policy, verbose):
            return UpdateStatus.UNCHANGED
        else:
            return UpdateStatus.GOOD

    def add_channel_updates(self, payloads, max_age=None) -> CategorizedChannelUpdates:
        """Run add_channel_update() on every payload and group the payloads by result."""
        by_status = {status: [] for status in UpdateStatus}
        for payload in payloads:
            r = self.add_channel_update(payload, max_age=max_age, verbose=False, verify=True)
            by_status[r].append(payload)
        self.update_counts()
        return CategorizedChannelUpdates(
            orphaned=by_status[UpdateStatus.ORPHANED],
            expired=by_status[UpdateStatus.EXPIRED],
            deprecated=by_status[UpdateStatus.DEPRECATED],
            unchanged=by_status[UpdateStatus.UNCHANGED],
            good=by_status[UpdateStatus.GOOD])

    def create_database(self):
        """Create the four tables (if missing) and commit."""
        c = self.conn.cursor()
        c.execute(create_node_info)
        c.execute(create_address)
        c.execute(create_policy)
        c.execute(create_channel_info)
        self.conn.commit()

    # NOTE(review): the methods below are wrapped by `sql` from .sql_db; how the
    # decorator schedules them is not visible in this file.

    @sql
    def _db_save_policy(self, key: bytes, msg: bytes):
        """Insert or replace the row for `key` in the policy table."""
        # 'msg' is a 'channel_update' message
        c = self.conn.cursor()
        c.execute("""REPLACE INTO policy (key, msg) VALUES (?,?)""", [key, msg])

    @sql
    def _db_delete_policy(self, node_id: bytes, short_channel_id: ShortChannelID):
        """Delete the policy row whose key is short_channel_id + node_id."""
        key = short_channel_id + node_id
        c = self.conn.cursor()
        c.execute("""DELETE FROM policy WHERE key=?""", (key,))

    @sql
    def _db_save_channel(self, short_channel_id: ShortChannelID, msg: bytes):
        """Insert or replace the row for `short_channel_id` in the channel_info table."""
        # 'msg' is a 'channel_announcement' message
        c = self.conn.cursor()
        c.execute("REPLACE INTO channel_info (short_channel_id, msg) VALUES (?,?)", [short_channel_id, msg])

    @sql
    def _db_delete_channel(self, short_channel_id: ShortChannelID):
        """Delete the row for `short_channel_id` from the channel_info table."""
        c = self.conn.cursor()
        c.execute("""DELETE FROM channel_info WHERE short_channel_id=?""", (short_channel_id,))

    @sql
    def _db_save_node_info(self, node_id: bytes, msg: bytes):
        """Insert or replace the row for `node_id` in the node_info table."""
        # 'msg' is a 'node_announcement' message
        c = self.conn.cursor()
        c.execute("REPLACE INTO node_info (node_id, msg) VALUES (?,?)", [node_id, msg])

    @sql
    def _db_save_node_address(self, peer: LNPeerAddr, timestamp: int):
        """Insert or replace the address row of `peer` with the given timestamp."""
        c = self.conn.cursor()
        c.execute("REPLACE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)",
                  (peer.pubkey, peer.host, peer.port, timestamp))

    @sql
    def _db_save_node_addresses(self, node_addresses: Sequence[LNPeerAddr]):
        """Insert addresses that are not stored yet, with timestamp 0.

        Existing rows are left alone, so their timestamp is not overwritten.
        """
        c = self.conn.cursor()
        for addr in node_addresses:
            c.execute("SELECT * FROM address WHERE node_id=? AND host=? AND port=?", (addr.pubkey, addr.host, addr.port))
            r = c.fetchall()
            if not r:
                c.execute("INSERT INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (addr.pubkey, addr.host, addr.port, 0))

    @classmethod
    def verify_channel_update(cls, payload, *, start_node: bytes = None) -> None:
        """Check chain hash and signature of a channel update.

        The signing node is payload['start_node'] if present (and truthy),
        otherwise the `start_node` argument; one of them must be given.
        Raises InvalidGossipMsg on a wrong chain hash or failed signature check.
        """
        short_channel_id = payload['short_channel_id']
        short_channel_id = ShortChannelID(short_channel_id)
        if constants.net.rev_genesis_bytes() != payload['chain_hash']:
            raise InvalidGossipMsg('wrong chain hash')
        start_node = payload.get('start_node', None) or start_node
        assert start_node is not None
        if not verify_sig_for_channel_update(payload, start_node):
            raise InvalidGossipMsg(f'failed verifying channel update for {short_channel_id}')

    @classmethod
    def verify_channel_announcement(cls, payload) -> None:
        """Check the four signatures of a channel announcement.

        Raises InvalidGossipMsg as soon as one signature does not verify.
        """
        # the hash covers the raw message after its first 2+256 bytes
        # (presumably the message type and the four 64-byte signatures -- TODO confirm)
        h = sha256d(payload['raw'][2+256:])
        pubkeys = [payload['node_id_1'], payload['node_id_2'], payload['bitcoin_key_1'], payload['bitcoin_key_2']]
        sigs = [payload['node_signature_1'], payload['node_signature_2'], payload['bitcoin_signature_1'], payload['bitcoin_signature_2']]
        for pubkey, sig in zip(pubkeys, sigs):
            if not ecc.verify_signature(pubkey, sig, h):
                raise InvalidGossipMsg('signature failed')

    @classmethod
    def verify_node_announcement(cls, payload) -> None:
        """Check the signature of a node announcement; raises InvalidGossipMsg on failure."""
        pubkey = payload['node_id']
        signature = payload['signature']
        # the hash covers the raw message after its first 66 bytes
        # (presumably the message type and the 64-byte signature -- TODO confirm)
        h = sha256d(payload['raw'][66:])
        if not ecc.verify_signature(pubkey, signature, h):
            raise InvalidGossipMsg('signature failed')

    def add_node_announcements(self, msg_payloads):
        """Add one payload dict, or an iterable of payload dicts, of node announcements.

        Skipped are payloads with unacceptable feature bits, nodes without any
        known channel, and announcements that are not newer than the stored one.
        """
        # note: signatures have already been verified.
        if isinstance(msg_payloads, dict):
            msg_payloads = [msg_payloads]
        new_nodes = {}
        for msg_payload in msg_payloads:
            try:
                node_info, node_addresses = NodeInfo.from_msg(msg_payload)
            except IncompatibleOrInsaneFeatures:
                continue
            node_id = node_info.node_id
            # Ignore node if it has no associated channel (DoS protection)
            if node_id not in self._channels_for_node:
                #self.logger.info('ignoring orphan node_announcement')
                continue
            # self._nodes is updated inside this loop, so this check also
            # covers older duplicates later in the same batch
            node = self._nodes.get(node_id)
            if node and node.timestamp >= node_info.timestamp:
                continue
            # save
            with self.lock:
                self._nodes[node_id] = node_info
            # previously never filled in, which made the log below report 0/N
            new_nodes[node_id] = node_info
            if 'raw' in msg_payload:
                self._db_save_node_info(node_id, msg_payload['raw'])
            with self.lock:
                for addr in node_addresses:
                    net_addr = NetAddress(addr.host, addr.port)
                    # keep an existing timestamp; new addresses start at 0
                    self._addresses[node_id][net_addr] = self._addresses[node_id].get(net_addr) or 0
            self._db_save_node_addresses(node_addresses)

        self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
        self.update_counts()

    def get_old_policies(self, delta) -> Sequence[Tuple[bytes, ShortChannelID]]:
        """Return the (node_id, scid) keys of policies whose timestamp is at least `delta` seconds old."""
        with self.lock:
            _policies = self._policies.copy()
        now = int(time.time())
        return list(k for k, v in _policies.items() if v.timestamp <= now - delta)

    def prune_old_policies(self, delta):
        """Remove the policies returned by get_old_policies(delta) from memory and database."""
        old_policies = self.get_old_policies(delta)
        if old_policies:
            for key in old_policies:
                node_id, scid = key
                with self.lock:
                    self._policies.pop(key)
                self._db_delete_policy(*key)
                self._update_num_policies_for_chan(scid)
            self.update_counts()
            self.logger.info(f'Deleting {len(old_policies)} old policies')

    def prune_orphaned_channels(self):
        """Remove every channel that currently has no policy in either direction."""
        with self.lock:
            orphaned_chans = self._chans_with_0_policies.copy()
        if orphaned_chans:
            for short_channel_id in orphaned_chans:
                self.remove_channel(short_channel_id)
            self.update_counts()
            self.logger.info(f'Deleting {len(orphaned_chans)} orphaned channels')

    def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes) -> bool:
        """Returns True iff the channel update was successfully added and it was different than
        what we had before (if any).
        """
        if not verify_sig_for_channel_update(msg_payload, start_node_id):
            return False  # ignore
        short_channel_id = ShortChannelID(msg_payload['short_channel_id'])
        msg_payload['start_node'] = start_node_id
        key = (start_node_id, short_channel_id)
        prev_chanupd = self._channel_updates_for_private_channels.get(key)
        if prev_chanupd == msg_payload:
            return False
        self._channel_updates_for_private_channels[key] = msg_payload
        return True

    def remove_channel(self, short_channel_id: ShortChannelID):
        """Forget a channel in memory and delete it from the database."""
        # FIXME what about rm-ing policies?
        with self.lock:
            channel_info = self._channels.pop(short_channel_id, None)
            if channel_info:
                # discard() instead of remove(): must not raise if both ends are
                # the same node or the per-node index lacks the entry
                self._channels_for_node[channel_info.node1_id].discard(channel_info.short_channel_id)
                self._channels_for_node[channel_info.node2_id].discard(channel_info.short_channel_id)
        self._update_num_policies_for_chan(short_channel_id)
        # delete from database
        self._db_delete_channel(short_channel_id)

    def get_node_addresses(self, node_id: bytes) -> Sequence[Tuple[str, int, int]]:
        """Returns list of (host, port, timestamp)."""
        # iterate under the lock: other threads may add addresses meanwhile
        with self.lock:
            addr_to_ts = self._addresses.get(node_id)
            if not addr_to_ts:
                return []
            return [(str(net_addr.host), net_addr.port, ts)
                    for net_addr, ts in addr_to_ts.items()]

    @sql
    @profiler
    def load_data(self):
        """Fill the in-memory structures from the database, then set data_loaded.

        Does nothing if the data has already been loaded.
        """
        if self.data_loaded.is_set():
            return
        # Note: this method takes several seconds... mostly due to lnmsg.decode_msg being slow.
        c = self.conn.cursor()
        c.execute("""SELECT * FROM address""")
        for x in c:
            node_id, host, port, timestamp = x
            try:
                net_addr = NetAddress(host, port)
            except Exception:
                # skip rows that do not form a valid address
                continue
            self._addresses[node_id][net_addr] = int(timestamp or 0)

        def newest_ts_for_node_id(node_id):
            # highest timestamp among the node's addresses, never below 0
            return max([0, *self._addresses[node_id].values()])

        sorted_node_ids = sorted(self._addresses.keys(), key=newest_ts_for_node_id, reverse=True)
        self._recent_peers = sorted_node_ids[:self.NUM_MAX_RECENT_PEERS]
        c.execute("""SELECT * FROM channel_info""")
        for short_channel_id, msg in c:
            try:
                ci = ChannelInfo.from_raw_msg(msg)
            except IncompatibleOrInsaneFeatures:
                continue
            self._channels[ShortChannelID.normalize(short_channel_id)] = ci
        c.execute("""SELECT * FROM node_info""")
        for node_id, msg in c:
            try:
                node_info, node_addresses = NodeInfo.from_raw_msg(msg)
            except IncompatibleOrInsaneFeatures:
                continue
            # don't load node_addresses because they dont have timestamps
            self._nodes[node_id] = node_info
        c.execute("""SELECT * FROM policy""")
        for key, msg in c:
            p = Policy.from_raw_msg(key, msg)
            self._policies[(p.start_node, p.short_channel_id)] = p
        for channel_info in self._channels.values():
            self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
            self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
            self._update_num_policies_for_chan(channel_info.short_channel_id)
        self.logger.info(f'data loaded. {len(self._channels)} chans. {len(self._policies)} policies. '
                         f'{len(self._channels_for_node)} nodes.')
        self.update_counts()
        (nchans_with_0p, nchans_with_1p, nchans_with_2p) = self.get_num_channels_partitioned_by_policy_count()
        self.logger.info(f'num_channels_partitioned_by_policy_count. '
                         f'0p: {nchans_with_0p}, 1p: {nchans_with_1p}, 2p: {nchans_with_2p}')
        # the event is set via the event loop rather than directly from here
        self.asyncio_loop.call_soon_threadsafe(self.data_loaded.set)
        util.trigger_callback('gossip_db_loaded')

    def _update_num_policies_for_chan(self, short_channel_id: ShortChannelID) -> None:
        """Re-file the channel into the 0/1/2-policies set matching its current policies.

        A channel that is no longer known is removed from all three sets.
        """
        channel_info = self.get_channel_info(short_channel_id)
        if channel_info is None:
            with self.lock:
                self._chans_with_0_policies.discard(short_channel_id)
                self._chans_with_1_policies.discard(short_channel_id)
                self._chans_with_2_policies.discard(short_channel_id)
            return
        p1 = self.get_policy_for_node(short_channel_id, channel_info.node1_id)
        p2 = self.get_policy_for_node(short_channel_id, channel_info.node2_id)
        with self.lock:
            self._chans_with_0_policies.discard(short_channel_id)
            self._chans_with_1_policies.discard(short_channel_id)
            self._chans_with_2_policies.discard(short_channel_id)
            if p1 is not None and p2 is not None:
                self._chans_with_2_policies.add(short_channel_id)
            elif p1 is None and p2 is None:
                self._chans_with_0_policies.add(short_channel_id)
            else:
                self._chans_with_1_policies.add(short_channel_id)

    def get_num_channels_partitioned_by_policy_count(self) -> Tuple[int, int, int]:
        """Return how many channels have 0, 1 and 2 known policies, in that order."""
        nchans_with_0p = len(self._chans_with_0_policies)
        nchans_with_1p = len(self._chans_with_1_policies)
        nchans_with_2p = len(self._chans_with_2_policies)
        return nchans_with_0p, nchans_with_1p, nchans_with_2p

    def get_policy_for_node(
            self,
            short_channel_id: bytes,
            node_id: bytes,
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
    ) -> Optional['Policy']:
        """Return the policy of the channel in the direction starting at `node_id`, if any.

        Lookup order: stored policy (known channels) or stored private channel
        update (unknown channels), then our own channels, then the given
        private route edges. Returns None if nothing matches.
        """
        channel_info = self.get_channel_info(short_channel_id)
        if channel_info is not None:  # publicly announced channel
            policy = self._policies.get((node_id, short_channel_id))
            if policy:
                return policy
        else:  # private channel
            chan_upd_dict = self._channel_updates_for_private_channels.get((node_id, short_channel_id))
            if chan_upd_dict:
                return Policy.from_msg(chan_upd_dict)
        # check if it's one of our own channels
        if my_channels:
            policy = get_mychannel_policy(short_channel_id, node_id, my_channels)
            if policy:
                return policy
        if private_route_edges:
            route_edge = private_route_edges.get(short_channel_id, None)
            if route_edge:
                return Policy.from_route_edge(route_edge)
        return None

    def get_channel_info(
            self,
            short_channel_id: ShortChannelID,
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
    ) -> Optional[ChannelInfo]:
        """Return the ChannelInfo for the channel, if any.

        Lookup order: known channels, then our own channels, then the given
        private route edges. Returns None if nothing matches.
        """
        ret = self._channels.get(short_channel_id)
        if ret:
            return ret
        # check if it's one of our own channels
        if my_channels:
            channel_info = get_mychannel_info(short_channel_id, my_channels)
            if channel_info:
                return channel_info
        if private_route_edges:
            route_edge = private_route_edges.get(short_channel_id)
            if route_edge:
                return ChannelInfo.from_route_edge(route_edge)
        return None

    def get_channels_for_node(
            self,
            node_id: bytes,
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
    ) -> Set[ShortChannelID]:
        """Returns the set of short channel IDs where node_id is one of the channel participants."""
        if not self.data_loaded.is_set():
            raise Exception("channelDB data not loaded yet!")
        relevant_channels = self._channels_for_node.get(node_id) or set()
        relevant_channels = set(relevant_channels)  # copy
        # add our own channels  # TODO maybe slow?
        if my_channels:
            for chan in my_channels.values():
                if node_id in (chan.node_id, chan.get_local_pubkey()):
                    relevant_channels.add(chan.short_channel_id)
        # add private channels  # TODO maybe slow?
        if private_route_edges:
            for route_edge in private_route_edges.values():
                if node_id in (route_edge.start_node, route_edge.end_node):
                    relevant_channels.add(route_edge.short_channel_id)
        return relevant_channels

    def get_endnodes_for_chan(self, short_channel_id: ShortChannelID, *,
                              my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[Tuple[bytes, bytes]]:
        """Return the two node ids of a channel, or None if the channel is unknown.

        For known channels this is (node1_id, node2_id); for one of our own
        channels it is (our pubkey, the channel's node_id).
        """
        channel_info = self.get_channel_info(short_channel_id)
        if channel_info is not None:  # publicly announced channel
            return channel_info.node1_id, channel_info.node2_id
        # check if it's one of our own channels
        if not my_channels:
            return None
        chan = my_channels.get(short_channel_id)  # type: Optional[Channel]
        if not chan:
            return None
        return chan.get_local_pubkey(), chan.node_id

    def get_node_info_for_node_id(self, node_id: bytes) -> Optional['NodeInfo']:
        """Return the stored NodeInfo for `node_id`, or None."""
        return self._nodes.get(node_id)

    def get_node_infos(self) -> Dict[bytes, NodeInfo]:
        """Return a shallow copy of the node_id -> NodeInfo mapping."""
        with self.lock:
            return self._nodes.copy()

    def get_node_policies(self) -> Dict[Tuple[bytes, ShortChannelID], Policy]:
        """Return a shallow copy of the (node_id, scid) -> Policy mapping."""
        with self.lock:
            return self._policies.copy()

    def get_node_by_prefix(self, prefix):
        """Return the first node id with a known address that starts with `prefix`.

        Raises Exception('node not found') if there is none.
        """
        with self.lock:
            for k in self._addresses:
                if k.startswith(prefix):
                    return k
        raise Exception('node not found')

    def to_dict(self) -> dict:
        """ Generates a graph representation in terms of a dictionary.

        The dictionary contains only native python types and can be encoded
        to json.
        """
        with self.lock:
            graph = {'nodes': [], 'channels': []}

            # gather nodes
            for pk, nodeinfo in self._nodes.items():
                # use _asdict() to convert NamedTuples to json encodable dicts
                node_entry = nodeinfo._asdict()
                # .get() on purpose: indexing the defaultdict would insert an
                # empty entry for every node without known addresses
                node_entry['addresses'] = [
                    {'host': str(addr.host), 'port': addr.port, 'timestamp': ts}
                    for addr, ts in self._addresses.get(pk, {}).items()
                ]
                graph['nodes'].append(node_entry)

            # gather channels
            for channelinfo in self._channels.values():
                channel_entry = channelinfo._asdict()
                policy1 = self._policies.get(
                    (channelinfo.node1_id, channelinfo.short_channel_id))
                policy2 = self._policies.get(
                    (channelinfo.node2_id, channelinfo.short_channel_id))
                channel_entry['policy1'] = policy1._asdict() if policy1 else None
                channel_entry['policy2'] = policy2._asdict() if policy2 else None
                graph['channels'].append(channel_entry)

        # need to use json_normalize otherwise json encoding in rpc server fails
        graph = json_normalize(graph)
        return graph
|