# -*- coding: utf-8 -*- # # Electrum - lightweight Bitcoin client # Copyright (C) 2018 The Electrum developers # # Permission is hereby granted, free of charge, to any person # obtaining a copy of this software and associated documentation files # (the "Software"), to deal in the Software without restriction, # including without limitation the rights to use, copy, modify, merge, # publish, distribute, sublicense, and/or sell copies of the Software, # and to permit persons to whom the Software is furnished to do so, # subject to the following conditions: # # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS # BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. import ipaddress import time import random import os from collections import defaultdict from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set import binascii import base64 import asyncio import threading from enum import IntEnum import functools from aiorpcx import NetAddress from electrum_ecc import ECPubkey from .sql_db import SqlDB, sql from . import constants, util from .util import profiler, get_headers_dir, is_ip_address, json_normalize, UserFacingException, is_private_netaddress from .lntransport import LNPeerAddr from .lnutil import (ShortChannelID, validate_features, IncompatibleOrInsaneFeatures, InvalidGossipMsg, GossipForwardingMessage, GossipTimestampFilter) from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update from .lnmsg import decode_msg from .crypto import sha256d from .lnmsg import FailedToParseMsg if TYPE_CHECKING: from .network import Network from .lnchannel import Channel from .lnrouter import RouteEdge from .simple_config import SimpleConfig FLAG_DISABLE = 1 << 1 FLAG_DIRECTION = 1 << 0 class ChannelDBNotLoaded(UserFacingException): pass class ChannelInfo(NamedTuple): short_channel_id: ShortChannelID node1_id: bytes node2_id: bytes capacity_sat: Optional[int] raw: Optional[bytes] = None @staticmethod def from_msg(payload: dict) -> 'ChannelInfo': features = int.from_bytes(payload['features'], 'big') features = validate_features(features) channel_id = payload['short_channel_id'] node_id_1 = payload['node_id_1'] node_id_2 = payload['node_id_2'] assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2] capacity_sat = None return ChannelInfo( short_channel_id = ShortChannelID.normalize(channel_id), node1_id = node_id_1, node2_id = node_id_2, capacity_sat = capacity_sat, raw = payload.get('raw') ) @staticmethod def from_raw_msg(raw: bytes) -> 'ChannelInfo': payload_dict = decode_msg(raw)[1] payload_dict['raw'] = raw return ChannelInfo.from_msg(payload_dict) @staticmethod def from_route_edge(route_edge: 'RouteEdge') -> 'ChannelInfo': node1_id, node2_id = sorted([route_edge.start_node, route_edge.end_node]) return ChannelInfo( short_channel_id=route_edge.short_channel_id, node1_id=node1_id, node2_id=node2_id, capacity_sat=None, ) class Policy(NamedTuple): key: bytes cltv_delta: int htlc_minimum_msat: int htlc_maximum_msat: Optional[int] fee_base_msat: int fee_proportional_millionths: int channel_flags: int message_flags: int timestamp: int raw: Optional[bytes] = None @staticmethod def from_msg(payload: dict) -> 'Policy': return Policy( key = payload['short_channel_id'] + payload['start_node'], cltv_delta = payload['cltv_expiry_delta'], htlc_minimum_msat = payload['htlc_minimum_msat'], htlc_maximum_msat = payload.get('htlc_maximum_msat', None), fee_base_msat = payload['fee_base_msat'], fee_proportional_millionths = payload['fee_proportional_millionths'], message_flags = int.from_bytes(payload['message_flags'], "big"), channel_flags = int.from_bytes(payload['channel_flags'], "big"), timestamp = payload['timestamp'], raw = payload.get('raw'), ) @staticmethod def from_raw_msg(key: bytes, raw: bytes) -> 'Policy': payload = decode_msg(raw)[1] payload['start_node'] = key[8:] payload['raw'] = raw return Policy.from_msg(payload) @staticmethod def from_route_edge(route_edge: 'RouteEdge') -> 'Policy': return Policy( key=route_edge.short_channel_id + route_edge.start_node, cltv_delta=route_edge.cltv_delta, htlc_minimum_msat=0, htlc_maximum_msat=None, fee_base_msat=route_edge.fee_base_msat, fee_proportional_millionths=route_edge.fee_proportional_millionths, channel_flags=0, message_flags=0, timestamp=0, ) def is_disabled(self): return self.channel_flags & FLAG_DISABLE @property def short_channel_id(self) -> ShortChannelID: return ShortChannelID.normalize(self.key[0:8]) @property def start_node(self) -> bytes: return self.key[8:] class NodeInfo(NamedTuple): node_id: bytes features: int timestamp: int alias: str raw: Optional[bytes] @staticmethod def from_msg(payload) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]: node_id = payload['node_id'] features = int.from_bytes(payload['features'], "big") features = validate_features(features) addresses = NodeInfo.parse_addresses_field(payload['addresses']) peer_addrs = [] for host, port in addresses: try: peer_addrs.append(LNPeerAddr(host=host, port=port, pubkey=node_id)) except ValueError: pass alias = payload['alias'].rstrip(b'\x00') try: alias = alias.decode('utf8') except Exception: alias = '' timestamp = payload['timestamp'] node_info = NodeInfo( node_id=node_id, features=features, timestamp=timestamp, alias=alias, raw=payload.get('raw')) return node_info, peer_addrs @staticmethod def from_raw_msg(raw: bytes) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]: payload_dict = decode_msg(raw)[1] payload_dict['raw'] = raw return NodeInfo.from_msg(payload_dict) @staticmethod def to_addresses_field(hostname: str, port: int) -> bytes: """Encodes a hostname/port pair into a BOLT-7 'addresses' field.""" if (NodeInfo.invalid_announcement_hostname(hostname) or port is None or port <= 0 or port > 65535): return b'' port_bytes = port.to_bytes(2, 'big') if is_ip_address(hostname): # ipv4 or ipv6 ip_addr = ipaddress.ip_address(hostname) if ip_addr.version == 4: return b'\x01' + ip_addr.packed + port_bytes elif ip_addr.version == 6: return b'\x02' + ip_addr.packed + port_bytes return b'' elif hostname.endswith('.onion'): # Tor onion v3 onion_addr: bytes = base64.b32decode(hostname[:-6], casefold=True) return b'\x04' + onion_addr + port_bytes else: try: hostname_ascii: bytes = hostname.encode('ascii') except UnicodeEncodeError: # encoding single characters to punycode (according to spec) doesn't make sense # as you can't differentiate them from regular ascii? encoding the whole string to punycode # doesn't work either as the receiver would interpret it as regular ascii. # hostname_ascii: bytes = hostname.encode('punycode') return b'' if len(hostname_ascii) + 3 > 258: # + 1 byte for length and 2 for port return b'' # too long return b'\x05' + len(hostname_ascii).to_bytes(1, "big") + hostname_ascii + port_bytes @staticmethod def invalid_announcement_hostname(hostname: Optional[str]) -> bool: """Returns True if hostname unsuited for publishing in a NodeAnnouncement.""" if (hostname is None or hostname == "" or is_private_netaddress(hostname) or hostname.startswith("http://") # not catching 'http' due to onion addresses or hostname.startswith("https://")): return True if hostname.endswith('.onion'): if len(hostname) != 62: # not an onion v3 link (probably onion v2) return True return False @staticmethod def parse_addresses_field(addresses_field): buf = addresses_field def read(n): nonlocal buf data, buf = buf[0:n], buf[n:] return data addresses = [] while buf: atype = ord(read(1)) if atype == 0: pass elif atype == 1: # IPv4 ipv4_addr = '.'.join(map(lambda x: '%d' % x, read(4))) port = int.from_bytes(read(2), 'big') if is_ip_address(ipv4_addr) and port != 0: addresses.append((ipv4_addr, port)) elif atype == 2: # IPv6 ipv6_addr = b':'.join([binascii.hexlify(read(2)) for i in range(8)]) ipv6_addr = ipv6_addr.decode('ascii') port = int.from_bytes(read(2), 'big') if is_ip_address(ipv6_addr) and port != 0: addresses.append((ipv6_addr, port)) elif atype == 3: # onion v2 read(12) # we skip onion v2 as it is deprecated elif atype == 4: # onion v3 host = base64.b32encode(read(35)) + b'.onion' host = host.decode('ascii').lower() port = int.from_bytes(read(2), 'big') addresses.append((host, port)) elif atype == 5: # dns hostname len_hostname = int.from_bytes(read(1), 'big') host = read(len_hostname).decode('ascii') port = int.from_bytes(read(2), 'big') if not NodeInfo.invalid_announcement_hostname(host) and port > 0: addresses.append((host, port)) else: # unknown address type # we don't know how long it is -> have to escape # if there are other addresses we could have parsed later, they are lost. break return addresses class UpdateStatus(IntEnum): ORPHANED = 0 EXPIRED = 1 DEPRECATED = 2 UNCHANGED = 3 GOOD = 4 class CategorizedChannelUpdates(NamedTuple): orphaned: List # no channel announcement for channel update expired: List # update older than two weeks deprecated: List # update older than database entry unchanged: List # unchanged policies good: List # good updates def get_mychannel_info(short_channel_id: ShortChannelID, my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[ChannelInfo]: chan = my_channels.get(short_channel_id) if not chan: return raw_msg, _ = chan.construct_channel_announcement_without_sigs() ci = ChannelInfo.from_raw_msg(raw_msg) return ci._replace(capacity_sat=chan.constraints.capacity) def get_mychannel_policy(short_channel_id: bytes, node_id: bytes, my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[Policy]: chan = my_channels.get(short_channel_id) # type: Optional[Channel] if not chan: return if node_id == chan.node_id: # incoming direction (to us) remote_update_raw = chan.get_remote_update() if not remote_update_raw: return now = int(time.time()) remote_update_decoded = decode_msg(remote_update_raw)[1] remote_update_decoded['timestamp'] = now remote_update_decoded['start_node'] = node_id return Policy.from_msg(remote_update_decoded) elif node_id == chan.get_local_pubkey(): # outgoing direction (from us) local_update_decoded = decode_msg(chan.get_outgoing_gossip_channel_update())[1] local_update_decoded['start_node'] = node_id return Policy.from_msg(local_update_decoded) class _LoadDataAborted(Exception): pass create_channel_info = """ CREATE TABLE IF NOT EXISTS channel_info ( short_channel_id BLOB(8), msg BLOB, PRIMARY KEY(short_channel_id) )""" create_policy = """ CREATE TABLE IF NOT EXISTS policy ( key BLOB(41), msg BLOB, PRIMARY KEY(key) )""" create_address = """ CREATE TABLE IF NOT EXISTS address ( node_id BLOB(33), host STRING(256), port INTEGER NOT NULL, timestamp INTEGER, PRIMARY KEY(node_id, host, port) )""" create_node_info = """ CREATE TABLE IF NOT EXISTS node_info ( node_id BLOB(33), msg BLOB, PRIMARY KEY(node_id) )""" class ChannelDB(SqlDB): NUM_MAX_RECENT_PEERS = 20 PRIVATE_CHAN_UPD_CACHE_TTL_NORMAL = 600 PRIVATE_CHAN_UPD_CACHE_TTL_SHORT = 120 def __init__(self, network: 'Network'): path = self.get_file_path(network.config) super().__init__(network.asyncio_loop, path, commit_interval=100) self.lock = threading.RLock() self.num_nodes = 0 self.num_channels = 0 self.num_policies = 0 self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], Tuple[dict, int]] # note: ^ we could maybe move this cache into PaySession instead of being global. # That would only make sense though if PaySessions were never too short # (e.g. consider trampoline forwarding). self.ca_verifier = LNChannelVerifier(network, self) # initialized in load_data # note: modify/iterate needs self.lock self._channels = {} # type: Dict[ShortChannelID, ChannelInfo] self._policies = {} # type: Dict[Tuple[bytes, ShortChannelID], Policy] # (node_id, scid) -> Policy self._nodes = {} # type: Dict[bytes, NodeInfo] # node_id -> NodeInfo # node_id -> NetAddress -> timestamp self._addresses = defaultdict(dict) # type: Dict[bytes, Dict[NetAddress, int]] self._channels_for_node = defaultdict(set) # type: Dict[bytes, Set[ShortChannelID]] self._recent_peers = [] # type: List[bytes] # list of node_ids self._chans_with_0_policies = set() # type: Set[ShortChannelID] self._chans_with_1_policies = set() # type: Set[ShortChannelID] self._chans_with_2_policies = set() # type: Set[ShortChannelID] self.forwarding_lock = threading.RLock() self.fwd_channels = [] # type: List[GossipForwardingMessage] self.fwd_orphan_channels = [] # type: List[GossipForwardingMessage] self.fwd_channel_updates = [] # type: List[GossipForwardingMessage] self.fwd_node_announcements = [] # type: List[GossipForwardingMessage] self.data_loaded = asyncio.Event() self.network = network # only for callback @classmethod def get_file_path(cls, config: 'SimpleConfig') -> str: return os.path.join(get_headers_dir(config), 'gossip_db') def update_counts(self): self.num_nodes = len(self._nodes) self.num_channels = len(self._channels) self.num_policies = len(self._policies) util.trigger_callback('channel_db', self.num_nodes, self.num_channels, self.num_policies) util.trigger_callback('ln_gossip_sync_progress') def get_channel_ids(self): with self.lock: return set(self._channels.keys()) def add_recent_peer(self, peer: LNPeerAddr): now = int(time.time()) node_id = peer.pubkey with self.lock: self._addresses[node_id][peer.net_addr()] = now # list is ordered if node_id in self._recent_peers: self._recent_peers.remove(node_id) self._recent_peers.insert(0, node_id) self._recent_peers = self._recent_peers[:self.NUM_MAX_RECENT_PEERS] self._db_save_node_address(peer, now) def get_200_randomly_sorted_nodes_not_in(self, node_ids): with self.lock: unshuffled = set(self._nodes.keys()) - node_ids return random.sample(list(unshuffled), min(200, len(unshuffled))) def get_last_good_address(self, node_id: bytes) -> Optional[LNPeerAddr]: """Returns latest address we successfully connected to, for given node.""" addr_to_ts = self._addresses.get(node_id) if not addr_to_ts: return None addr = sorted(list(addr_to_ts), key=lambda a: addr_to_ts[a], reverse=True)[0] try: return LNPeerAddr(str(addr.host), addr.port, node_id) except ValueError: return None def get_recent_peers(self): if not self.data_loaded.is_set(): raise ChannelDBNotLoaded("channelDB data not loaded yet!") with self.lock: ret = [self.get_last_good_address(node_id) for node_id in self._recent_peers] return ret # note: currently channel announcements are trusted by default (trusted=True); # they are not SPV-verified. Verifying them would make the gossip sync # even slower; especially as servers will start throttling us. # It would probably put significant strain on servers if all clients # verified the complete gossip. def add_channel_announcements(self, msg_payloads, *, trusted=True): # note: signatures have already been verified. if type(msg_payloads) is dict: msg_payloads = [msg_payloads] added = 0 for msg in msg_payloads: short_channel_id = ShortChannelID(msg['short_channel_id']) if short_channel_id in self._channels: continue if constants.net.rev_genesis_bytes() != msg['chain_hash']: self.logger.info("ChanAnn has unexpected chain_hash {}".format(msg['chain_hash'].hex())) continue try: channel_info = ChannelInfo.from_msg(msg) except IncompatibleOrInsaneFeatures as e: self.logger.info(f"unknown or insane feature bits: {e!r}") continue if trusted: added += 1 self.add_verified_channel_info(msg) else: added += self.ca_verifier.add_new_channel_info(short_channel_id, msg) self.update_counts() def add_verified_channel_info(self, msg: dict, *, capacity_sat: int = None) -> None: try: channel_info = ChannelInfo.from_msg(msg) except IncompatibleOrInsaneFeatures: return channel_info = channel_info._replace(capacity_sat=capacity_sat) with self.lock: self._channels[channel_info.short_channel_id] = channel_info self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id) self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id) self._update_num_policies_for_chan(channel_info.short_channel_id) if 'raw' in msg: self._db_save_channel(channel_info.short_channel_id, msg['raw']) with self.forwarding_lock: if fwd_msg := GossipForwardingMessage.from_payload(msg): self.fwd_channels.append(fwd_msg) def policy_changed(self, old_policy: Policy, new_policy: Policy, verbose: bool) -> bool: changed = False if old_policy.cltv_delta != new_policy.cltv_delta: changed |= True if verbose: self.logger.info(f'cltv_expiry_delta: {old_policy.cltv_delta} -> {new_policy.cltv_delta}') if old_policy.htlc_minimum_msat != new_policy.htlc_minimum_msat: changed |= True if verbose: self.logger.info(f'htlc_minimum_msat: {old_policy.htlc_minimum_msat} -> {new_policy.htlc_minimum_msat}') if old_policy.htlc_maximum_msat != new_policy.htlc_maximum_msat: changed |= True if verbose: self.logger.info(f'htlc_maximum_msat: {old_policy.htlc_maximum_msat} -> {new_policy.htlc_maximum_msat}') if old_policy.fee_base_msat != new_policy.fee_base_msat: changed |= True if verbose: self.logger.info(f'fee_base_msat: {old_policy.fee_base_msat} -> {new_policy.fee_base_msat}') if old_policy.fee_proportional_millionths != new_policy.fee_proportional_millionths: changed |= True if verbose: self.logger.info(f'fee_proportional_millionths: {old_policy.fee_proportional_millionths} -> {new_policy.fee_proportional_millionths}') if old_policy.channel_flags != new_policy.channel_flags: changed |= True if verbose: self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}') if old_policy.message_flags != new_policy.message_flags: changed |= True if verbose: self.logger.info(f'message_flags: {old_policy.message_flags} -> {new_policy.message_flags}') if not changed and verbose: self.logger.info(f'policy unchanged: {old_policy.timestamp} -> {new_policy.timestamp}') return changed def add_channel_update( self, payload, *, max_age=None, verify=True, verbose=True) -> UpdateStatus: now = int(time.time()) short_channel_id = ShortChannelID(payload['short_channel_id']) timestamp = payload['timestamp'] if max_age and now - timestamp > max_age: return UpdateStatus.EXPIRED if timestamp - now > 60: return UpdateStatus.DEPRECATED channel_info = self._channels.get(short_channel_id) if not channel_info: return UpdateStatus.ORPHANED flags = int.from_bytes(payload['channel_flags'], 'big') direction = flags & FLAG_DIRECTION start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id payload['start_node'] = start_node # compare updates to existing database entries short_channel_id = ShortChannelID(payload['short_channel_id']) key = (start_node, short_channel_id) old_policy = self._policies.get(key) if old_policy and timestamp <= old_policy.timestamp + 60: return UpdateStatus.DEPRECATED if verify: self.verify_channel_update(payload) policy = Policy.from_msg(payload) with self.lock: self._policies[key] = policy self._update_num_policies_for_chan(short_channel_id) if 'raw' in payload: self._db_save_policy(policy.key, payload['raw']) if old_policy and not self.policy_changed(old_policy, policy, verbose): return UpdateStatus.UNCHANGED else: if policy.message_flags & 0b10 == 0: # check if its `dont_forward` with self.forwarding_lock: if fwd_msg := GossipForwardingMessage.from_payload(payload): self.fwd_channel_updates.append(fwd_msg) return UpdateStatus.GOOD def add_channel_updates(self, payloads, max_age=None) -> CategorizedChannelUpdates: orphaned = [] expired = [] deprecated = [] unchanged = [] good = [] for payload in payloads: r = self.add_channel_update(payload, max_age=max_age, verbose=False, verify=True) if r == UpdateStatus.ORPHANED: orphaned.append(payload) elif r == UpdateStatus.EXPIRED: expired.append(payload) elif r == UpdateStatus.DEPRECATED: deprecated.append(payload) elif r == UpdateStatus.UNCHANGED: unchanged.append(payload) elif r == UpdateStatus.GOOD: good.append(payload) self.update_counts() return CategorizedChannelUpdates( orphaned=orphaned, expired=expired, deprecated=deprecated, unchanged=unchanged, good=good) def create_database(self): c = self.conn.cursor() c.execute(create_node_info) c.execute(create_address) c.execute(create_policy) c.execute(create_channel_info) self.conn.commit() @sql def _db_save_policy(self, key: bytes, msg: bytes): # 'msg' is a 'channel_update' message c = self.conn.cursor() c.execute("""REPLACE INTO policy (key, msg) VALUES (?,?)""", [key, msg]) @sql def _db_delete_policy(self, node_id: bytes, short_channel_id: ShortChannelID): key = short_channel_id + node_id c = self.conn.cursor() c.execute("""DELETE FROM policy WHERE key=?""", (key,)) @sql def _db_save_channel(self, short_channel_id: ShortChannelID, msg: bytes): # 'msg' is a 'channel_announcement' message c = self.conn.cursor() c.execute("REPLACE INTO channel_info (short_channel_id, msg) VALUES (?,?)", [short_channel_id, msg]) @sql def _db_delete_channel(self, short_channel_id: ShortChannelID): c = self.conn.cursor() c.execute("""DELETE FROM channel_info WHERE short_channel_id=?""", (short_channel_id,)) @sql def _db_save_node_info(self, node_id: bytes, msg: bytes): # 'msg' is a 'node_announcement' message c = self.conn.cursor() c.execute("REPLACE INTO node_info (node_id, msg) VALUES (?,?)", [node_id, msg]) @sql def _db_save_node_address(self, peer: LNPeerAddr, timestamp: int): c = self.conn.cursor() c.execute("REPLACE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (peer.pubkey, peer.host, peer.port, timestamp)) @sql def _db_save_node_addresses(self, node_addresses: Sequence[LNPeerAddr]): c = self.conn.cursor() for addr in node_addresses: c.execute("SELECT * FROM address WHERE node_id=? AND host=? AND port=?", (addr.pubkey, addr.host, addr.port)) r = c.fetchall() if r == []: c.execute("INSERT INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (addr.pubkey, addr.host, addr.port, 0)) @classmethod def verify_channel_update(cls, payload, *, start_node: bytes = None) -> None: short_channel_id = payload['short_channel_id'] short_channel_id = ShortChannelID(short_channel_id) if constants.net.rev_genesis_bytes() != payload['chain_hash']: raise InvalidGossipMsg('wrong chain hash') start_node = payload.get('start_node', None) or start_node assert start_node is not None if not verify_sig_for_channel_update(payload, start_node): raise InvalidGossipMsg(f'failed verifying channel update for {short_channel_id}') @classmethod def verify_channel_announcement(cls, payload) -> None: h = sha256d(payload['raw'][2+256:]) pubkeys = [payload['node_id_1'], payload['node_id_2'], payload['bitcoin_key_1'], payload['bitcoin_key_2']] sigs = [payload['node_signature_1'], payload['node_signature_2'], payload['bitcoin_signature_1'], payload['bitcoin_signature_2']] for pubkey, sig in zip(pubkeys, sigs): if not ECPubkey(pubkey).ecdsa_verify(sig, h): raise InvalidGossipMsg('signature failed') @classmethod def verify_node_announcement(cls, payload) -> None: pubkey = payload['node_id'] signature = payload['signature'] h = sha256d(payload['raw'][66:]) if not ECPubkey(pubkey).ecdsa_verify(signature, h): raise InvalidGossipMsg('signature failed') def add_node_announcements(self, msg_payloads): # note: signatures have already been verified. if type(msg_payloads) is dict: msg_payloads = [msg_payloads] new_nodes = set() # type: Set[bytes] for msg_payload in msg_payloads: try: node_info, node_addresses = NodeInfo.from_msg(msg_payload) except IncompatibleOrInsaneFeatures: continue node_id = node_info.node_id # Ignore node if it has no associated channel (DoS protection) if node_id not in self._channels_for_node: #self.logger.info('ignoring orphan node_announcement') continue node = self._nodes.get(node_id) if node and node.timestamp >= node_info.timestamp: continue new_nodes.add(node_id) # save with self.lock: self._nodes[node_id] = node_info if 'raw' in msg_payload: self._db_save_node_info(node_id, msg_payload['raw']) with self.lock: for addr in node_addresses: net_addr = NetAddress(addr.host, addr.port) self._addresses[node_id][net_addr] = self._addresses[node_id].get(net_addr) or 0 self._db_save_node_addresses(node_addresses) with self.forwarding_lock: if fwd_msg := GossipForwardingMessage.from_payload(msg_payload): self.fwd_node_announcements.append(fwd_msg) self.update_counts() def get_old_policies(self, delta) -> Sequence[Tuple[bytes, ShortChannelID]]: with self.lock: _policies = self._policies.copy() now = int(time.time()) return list(k for k, v in _policies.items() if v.timestamp <= now - delta) def prune_old_policies(self, delta): old_policies = self.get_old_policies(delta) if old_policies: for key in old_policies: node_id, scid = key with self.lock: self._policies.pop(key) self._db_delete_policy(*key) self._update_num_policies_for_chan(scid) self.update_counts() self.logger.info(f'Deleting {len(old_policies)} old policies') def prune_orphaned_channels(self): with self.lock: orphaned_chans = self._chans_with_0_policies.copy() if orphaned_chans: for short_channel_id in orphaned_chans: self.remove_channel(short_channel_id) self.update_counts() self.logger.info(f'Deleting {len(orphaned_chans)} orphaned channels') def _get_channel_update_for_private_channel( self, start_node_id: bytes, short_channel_id: ShortChannelID, *, now: int = None, # unix ts ) -> Optional[dict]: if now is None: now = int(time.time()) key = (start_node_id, short_channel_id) chan_upd_dict, cache_expiration = self._channel_updates_for_private_channels.get(key, (None, 0)) if cache_expiration < now: chan_upd_dict = None # already expired # TODO rm expired entries from cache (note: perf vs thread-safety) return chan_upd_dict def add_channel_update_for_private_channel( self, msg_payload: dict, start_node_id: bytes, *, cache_ttl: int = None, # seconds ) -> bool: """Returns True iff the channel update was successfully added and it was different than what we had before (if any). """ if not verify_sig_for_channel_update(msg_payload, start_node_id): return False # ignore now = int(time.time()) short_channel_id = ShortChannelID(msg_payload['short_channel_id']) msg_payload['start_node'] = start_node_id prev_chanupd = self._get_channel_update_for_private_channel(start_node_id, short_channel_id, now=now) if prev_chanupd == msg_payload: return False if cache_ttl is None: cache_ttl = self.PRIVATE_CHAN_UPD_CACHE_TTL_NORMAL cache_expiration = now + cache_ttl key = (start_node_id, short_channel_id) with self.lock: self._channel_updates_for_private_channels[key] = msg_payload, cache_expiration return True def remove_channel(self, short_channel_id: ShortChannelID): # FIXME what about rm-ing policies? with self.lock: channel_info = self._channels.pop(short_channel_id, None) if channel_info: self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id) self._channels_for_node[channel_info.node2_id].remove(channel_info.short_channel_id) self._update_num_policies_for_chan(short_channel_id) # delete from database self._db_delete_channel(short_channel_id) def get_node_addresses(self, node_id: bytes) -> Sequence[Tuple[str, int, int]]: """Returns list of (host, port, timestamp).""" addr_to_ts = self._addresses.get(node_id) if not addr_to_ts: return [] return [(str(net_addr.host), net_addr.port, ts) for net_addr, ts in addr_to_ts.items()] def handle_abort(func): @functools.wraps(func) def wrapper(self: 'ChannelDB', *args, **kwargs): try: return func(self, *args, **kwargs) except _LoadDataAborted: return return wrapper @sql @profiler @handle_abort def load_data(self): if self.data_loaded.is_set(): return # Note: this method takes several seconds... mostly due to lnmsg.decode_msg being slow. def maybe_abort(): if self.stopping: self.logger.info("load_data() was asked to stop. exiting early.") raise _LoadDataAborted() c = self.conn.cursor() c.execute("""SELECT * FROM address""") for x in c: maybe_abort() node_id, host, port, timestamp = x try: net_addr = NetAddress(host, port) except Exception: continue self._addresses[node_id][net_addr] = int(timestamp or 0) def newest_ts_for_node_id(node_id): newest_ts = 0 for addr, ts in self._addresses[node_id].items(): newest_ts = max(newest_ts, ts) return newest_ts sorted_node_ids = sorted(self._addresses.keys(), key=newest_ts_for_node_id, reverse=True) self._recent_peers = sorted_node_ids[:self.NUM_MAX_RECENT_PEERS] c.execute("""SELECT * FROM channel_info""") for short_channel_id, msg in c: maybe_abort() try: ci = ChannelInfo.from_raw_msg(msg) except IncompatibleOrInsaneFeatures: continue except FailedToParseMsg: continue self._channels[ShortChannelID.normalize(short_channel_id)] = ci c.execute("""SELECT * FROM node_info""") for node_id, msg in c: maybe_abort() try: node_info, node_addresses = NodeInfo.from_raw_msg(msg) except IncompatibleOrInsaneFeatures: continue except FailedToParseMsg: continue # don't load node_addresses because they dont have timestamps self._nodes[node_id] = node_info c.execute("""SELECT * FROM policy""") for key, msg in c: maybe_abort() try: p = Policy.from_raw_msg(key, msg) except FailedToParseMsg: continue self._policies[(p.start_node, p.short_channel_id)] = p for channel_info in self._channels.values(): self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id) self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id) self._update_num_policies_for_chan(channel_info.short_channel_id) self.logger.info(f'data loaded. {len(self._channels)} chans. {len(self._policies)} policies. ' f'{len(self._channels_for_node)} nodes.') self.update_counts() (nchans_with_0p, nchans_with_1p, nchans_with_2p) = self.get_num_channels_partitioned_by_policy_count() self.logger.info(f'num_channels_partitioned_by_policy_count. ' f'0p: {nchans_with_0p}, 1p: {nchans_with_1p}, 2p: {nchans_with_2p}') self.asyncio_loop.call_soon_threadsafe(self.data_loaded.set) util.trigger_callback('gossip_db_loaded') def _update_num_policies_for_chan(self, short_channel_id: ShortChannelID) -> None: channel_info = self.get_channel_info(short_channel_id) if channel_info is None: with self.lock: self._chans_with_0_policies.discard(short_channel_id) self._chans_with_1_policies.discard(short_channel_id) self._chans_with_2_policies.discard(short_channel_id) return p1 = self.get_policy_for_node(short_channel_id, channel_info.node1_id) p2 = self.get_policy_for_node(short_channel_id, channel_info.node2_id) with self.lock: self._chans_with_0_policies.discard(short_channel_id) self._chans_with_1_policies.discard(short_channel_id) self._chans_with_2_policies.discard(short_channel_id) if p1 is not None and p2 is not None: self._chans_with_2_policies.add(short_channel_id) elif p1 is None and p2 is None: self._chans_with_0_policies.add(short_channel_id) else: self._chans_with_1_policies.add(short_channel_id) def get_num_channels_partitioned_by_policy_count(self) -> Tuple[int, int, int]: nchans_with_0p = len(self._chans_with_0_policies) nchans_with_1p = len(self._chans_with_1_policies) nchans_with_2p = len(self._chans_with_2_policies) return nchans_with_0p, nchans_with_1p, nchans_with_2p def get_policy_for_node( self, short_channel_id: ShortChannelID, node_id: bytes, *, my_channels: Dict[ShortChannelID, 'Channel'] = None, private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None, now: int = None, # unix ts ) -> Optional['Policy']: channel_info = self.get_channel_info(short_channel_id) if channel_info is not None: # publicly announced channel policy = self._policies.get((node_id, short_channel_id)) if policy: return policy elif chan_upd_dict := self._get_channel_update_for_private_channel(node_id, short_channel_id, now=now): return Policy.from_msg(chan_upd_dict) # check if it's one of our own channels if my_channels: policy = get_mychannel_policy(short_channel_id, node_id, my_channels) if policy: return policy if private_route_edges: route_edge = private_route_edges.get(short_channel_id, None) if route_edge: return Policy.from_route_edge(route_edge) def get_channel_info( self, short_channel_id: ShortChannelID, *, my_channels: Dict[ShortChannelID, 'Channel'] = None, private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None, ) -> Optional[ChannelInfo]: ret = self._channels.get(short_channel_id) if ret: return ret # check if it's one of our own channels if my_channels: channel_info = get_mychannel_info(short_channel_id, my_channels) if channel_info: return channel_info if private_route_edges: route_edge = private_route_edges.get(short_channel_id) if route_edge: return ChannelInfo.from_route_edge(route_edge) def get_channels_for_node( self, node_id: bytes, *, my_channels: Dict[ShortChannelID, 'Channel'] = None, private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None, ) -> Set[ShortChannelID]: """Returns the set of short channel IDs where node_id is one of the channel participants.""" if not self.data_loaded.is_set(): raise ChannelDBNotLoaded("channelDB data not loaded yet!") relevant_channels = self._channels_for_node.get(node_id) or set() relevant_channels = set(relevant_channels) # copy # add our own channels # TODO maybe slow? if my_channels: for chan in my_channels.values(): if node_id in (chan.node_id, chan.get_local_pubkey()): relevant_channels.add(chan.short_channel_id) # add private channels # TODO maybe slow? if private_route_edges: for route_edge in private_route_edges.values(): if node_id in (route_edge.start_node, route_edge.end_node): relevant_channels.add(route_edge.short_channel_id) return relevant_channels def get_endnodes_for_chan(self, short_channel_id: ShortChannelID, *, my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[Tuple[bytes, bytes]]: channel_info = self.get_channel_info(short_channel_id) if channel_info is not None: # publicly announced channel return channel_info.node1_id, channel_info.node2_id # check if it's one of our own channels if not my_channels: return chan = my_channels.get(short_channel_id) # type: Optional[Channel] if not chan: return return chan.get_local_pubkey(), chan.node_id def get_node_info_for_node_id(self, node_id: bytes) -> Optional['NodeInfo']: return self._nodes.get(node_id) def get_node_infos(self) -> Dict[bytes, NodeInfo]: with self.lock: return self._nodes.copy() def get_node_policies(self) -> Dict[Tuple[bytes, ShortChannelID], Policy]: with self.lock: return self._policies.copy() def get_node_by_prefix(self, prefix): with self.lock: for k in self._addresses.keys(): if k.startswith(prefix): return k raise Exception('node not found') def clear_forwarding_gossip(self) -> None: with self.forwarding_lock: self.fwd_channels.clear() self.fwd_channel_updates.clear() self.fwd_node_announcements.clear() def filter_orphan_channel_anns( self, channel_anns: List[GossipForwardingMessage] ) -> Tuple[List, List]: """Check if the channel announcements we want to forward have at least 1 update""" to_forward_anns = [] orphaned_channel_anns = [] for channel in channel_anns: if channel.scid is None: continue elif (channel.scid in self._chans_with_1_policies or channel.scid in self._chans_with_2_policies): to_forward_anns.append(channel) continue orphaned_channel_anns.append(channel) return to_forward_anns, orphaned_channel_anns def set_fwd_channel_anns_ts(self, channel_anns: List[GossipForwardingMessage]) \ -> List[GossipForwardingMessage]: """Set the timestamps of the passed channel announcements from the corresponding policies""" timestamped_chan_anns: List[GossipForwardingMessage] = [] with self.lock: policies = self._policies.copy() channels = self._channels.copy() for chan_ann in channel_anns: if chan_ann.timestamp is not None: timestamped_chan_anns.append(chan_ann) continue scid = chan_ann.scid if (channel_info := channels.get(scid)) is None: continue policy1 = policies.get((channel_info.node1_id, scid)) policy2 = policies.get((channel_info.node2_id, scid)) potential_timestamps = [] for policy in [policy1, policy2]: if policy is not None: potential_timestamps.append(policy.timestamp) if not potential_timestamps: continue chan_ann.timestamp = min(potential_timestamps) timestamped_chan_anns.append(chan_ann) return timestamped_chan_anns def get_forwarding_gossip_batch(self) -> List[GossipForwardingMessage]: with self.forwarding_lock: fwd_gossip = self.fwd_channel_updates + self.fwd_node_announcements channel_anns = self.fwd_channels.copy() self.clear_forwarding_gossip() fwd_chan_anns1, _ = self.filter_orphan_channel_anns(self.fwd_orphan_channels) fwd_chan_anns2, self.fwd_orphan_channels = self.filter_orphan_channel_anns(channel_anns) channel_anns = self.set_fwd_channel_anns_ts(fwd_chan_anns1 + fwd_chan_anns2) return channel_anns + fwd_gossip def get_gossip_in_timespan(self, timespan: GossipTimestampFilter) \ -> List[GossipForwardingMessage]: """Return a list of gossip messages matching the requested timespan.""" forwarding_gossip = [] with self.lock: chans = self._channels.copy() policies = self._policies.copy() nodes = self._nodes.copy() for short_id, chan in chans.items(): # fetching the timestamp from the channel update (according to BOLT-07) chan_up_n1 = policies.get((chan.node1_id, short_id)) chan_up_n2 = policies.get((chan.node2_id, short_id)) updates = [] for policy in [chan_up_n1, chan_up_n2]: if policy and policy.raw and timespan.in_range(policy.timestamp): if policy.message_flags & 0b10 == 0: # check that its not "dont_forward" updates.append(GossipForwardingMessage( msg=policy.raw, timestamp=policy.timestamp)) if not updates or chan.raw is None: continue chan_ann_ts = min(update.timestamp for update in updates) channel_announcement = GossipForwardingMessage(msg=chan.raw, timestamp=chan_ann_ts) forwarding_gossip.extend([channel_announcement] + updates) for node_ann in nodes.values(): if timespan.in_range(node_ann.timestamp) and node_ann.raw: forwarding_gossip.append(GossipForwardingMessage( msg=node_ann.raw, timestamp=node_ann.timestamp)) return forwarding_gossip def get_channels_in_range(self, first_blocknum: int, number_of_blocks: int) -> List[ShortChannelID]: with self.lock: channels = self._channels.copy() scids: List[ShortChannelID] = [] for scid in channels: if first_blocknum <= scid.block_height < first_blocknum + number_of_blocks: scids.append(scid) scids.sort() return scids def get_gossip_for_scid_request(self, scid: ShortChannelID) -> List[bytes]: requested_gossip = [] chan_ann = self._channels.get(scid) if not chan_ann or not chan_ann.raw: return [] chan_up1 = self._policies.get((chan_ann.node1_id, scid)) chan_up2 = self._policies.get((chan_ann.node2_id, scid)) node_ann1 = self._nodes.get(chan_ann.node1_id) node_ann2 = self._nodes.get(chan_ann.node2_id) for msg in [chan_ann, chan_up1, chan_up2, node_ann1, node_ann2]: if msg and msg.raw: requested_gossip.append(msg.raw) return requested_gossip def to_dict(self) -> dict: """ Generates a graph representation in terms of a dictionary. The dictionary contains only native python types and can be encoded to json. """ with self.lock: graph = {'nodes': [], 'channels': []} # gather nodes for pk, nodeinfo in self._nodes.items(): # use _asdict() to convert NamedTuples to json encodable dicts graph['nodes'].append( nodeinfo._asdict(), ) graph['nodes'][-1]['addresses'] = [ {'host': str(addr.host), 'port': addr.port, 'timestamp': ts} for addr, ts in self._addresses[pk].items() ] # gather channels for cid, channelinfo in self._channels.items(): graph['channels'].append( channelinfo._asdict(), ) policy1 = self._policies.get( (channelinfo.node1_id, channelinfo.short_channel_id)) policy2 = self._policies.get( (channelinfo.node2_id, channelinfo.short_channel_id)) graph['channels'][-1]['policy1'] = policy1._asdict() if policy1 else None graph['channels'][-1]['policy2'] = policy2._asdict() if policy2 else None # need to use json_normalize otherwise json encoding in rpc server fails graph = json_normalize(graph) return graph