`suggest_node_channel_open()` suggested peers with onion hostnames even if the caller has no proxy enabled. This caused channel openings in the GUI to sometimes just not work, showing a `CancelledError()` because the client was unable to connect to the peer. Now only clearnet peers are recommended, as these are reachable without a proxy.
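A minimal sketch of the filtering idea described above, assuming a peer type with a `host` attribute; the helper name is hypothetical and this is not the actual change in `suggest_node_channel_open()`:

```python
# Hypothetical illustration of the fix, not the actual Electrum code:
# without a proxy, .onion hosts are unreachable, so keep clearnet peers only.
def filter_reachable_peers(peers, has_proxy: bool):
    if has_proxy:
        return peers  # with a proxy (e.g. Tor), onion peers are fine
    return [p for p in peers if not p.host.endswith('.onion')]
```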
# -*- coding: utf-8 -*-
#
# Electrum - lightweight Bitcoin client
# Copyright (C) 2018 The Electrum developers
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import ipaddress
import time
import random
import os
from collections import defaultdict
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
import binascii
import base64
import asyncio
import threading
from enum import IntEnum
import functools

from aiorpcx import NetAddress
from electrum_ecc import ECPubkey

from .sql_db import SqlDB, sql
from . import constants, util
from .util import profiler, get_headers_dir, is_ip_address, json_normalize, UserFacingException, is_private_netaddress
from .lntransport import LNPeerAddr
from .lnutil import (ShortChannelID, validate_features, IncompatibleOrInsaneFeatures,
                     InvalidGossipMsg, GossipForwardingMessage, GossipTimestampFilter)
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
from .lnmsg import decode_msg, FailedToParseMsg
from .crypto import sha256d

if TYPE_CHECKING:
    from .network import Network
    from .lnchannel import Channel
    from .lnrouter import RouteEdge
    from .simple_config import SimpleConfig


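# BOLT-7 channel_update `channel_flags` bits:
# bit 0 is the direction (which end node the update comes from),
# bit 1 marks the channel as disabled.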
FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0


class ChannelDBNotLoaded(UserFacingException): pass


class ChannelInfo(NamedTuple):
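    """Static info about a channel, as parsed from a channel_announcement."""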
    short_channel_id: ShortChannelID
    node1_id: bytes
    node2_id: bytes
    capacity_sat: Optional[int]
    raw: Optional[bytes] = None

    @staticmethod
    def from_msg(payload: dict) -> 'ChannelInfo':
        features = int.from_bytes(payload['features'], 'big')
        features = validate_features(features)
        channel_id = payload['short_channel_id']
        node_id_1 = payload['node_id_1']
        node_id_2 = payload['node_id_2']
        assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
        capacity_sat = None
        return ChannelInfo(
            short_channel_id = ShortChannelID.normalize(channel_id),
            node1_id = node_id_1,
            node2_id = node_id_2,
            capacity_sat = capacity_sat,
            raw = payload.get('raw')
        )

    @staticmethod
    def from_raw_msg(raw: bytes) -> 'ChannelInfo':
        payload_dict = decode_msg(raw)[1]
        payload_dict['raw'] = raw
        return ChannelInfo.from_msg(payload_dict)

    @staticmethod
    def from_route_edge(route_edge: 'RouteEdge') -> 'ChannelInfo':
        node1_id, node2_id = sorted([route_edge.start_node, route_edge.end_node])
        return ChannelInfo(
            short_channel_id=route_edge.short_channel_id,
            node1_id=node1_id,
            node2_id=node2_id,
            capacity_sat=None,
        )


class Policy(NamedTuple):
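    """One direction of a channel's routing policy, as parsed from a channel_update."""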
    key: bytes
    cltv_delta: int
    htlc_minimum_msat: int
    htlc_maximum_msat: Optional[int]
    fee_base_msat: int
    fee_proportional_millionths: int
    channel_flags: int
    message_flags: int
    timestamp: int
    raw: Optional[bytes] = None

    @staticmethod
    def from_msg(payload: dict) -> 'Policy':
        return Policy(
            key = payload['short_channel_id'] + payload['start_node'],
            cltv_delta = payload['cltv_expiry_delta'],
            htlc_minimum_msat = payload['htlc_minimum_msat'],
            htlc_maximum_msat = payload.get('htlc_maximum_msat', None),
            fee_base_msat = payload['fee_base_msat'],
            fee_proportional_millionths = payload['fee_proportional_millionths'],
            message_flags = int.from_bytes(payload['message_flags'], "big"),
            channel_flags = int.from_bytes(payload['channel_flags'], "big"),
            timestamp = payload['timestamp'],
            raw = payload.get('raw'),
        )

    @staticmethod
    def from_raw_msg(key: bytes, raw: bytes) -> 'Policy':
        payload = decode_msg(raw)[1]
        payload['start_node'] = key[8:]
        payload['raw'] = raw
        return Policy.from_msg(payload)

    @staticmethod
    def from_route_edge(route_edge: 'RouteEdge') -> 'Policy':
        return Policy(
            key=route_edge.short_channel_id + route_edge.start_node,
            cltv_delta=route_edge.cltv_delta,
            htlc_minimum_msat=0,
            htlc_maximum_msat=None,
            fee_base_msat=route_edge.fee_base_msat,
            fee_proportional_millionths=route_edge.fee_proportional_millionths,
            channel_flags=0,
            message_flags=0,
            timestamp=0,
        )

    def is_disabled(self):
        return self.channel_flags & FLAG_DISABLE

    @property
    def short_channel_id(self) -> ShortChannelID:
        return ShortChannelID.normalize(self.key[0:8])

    @property
    def start_node(self) -> bytes:
        return self.key[8:]


class NodeInfo(NamedTuple):
    node_id: bytes
    features: int
    timestamp: int
    alias: str
    raw: Optional[bytes]

    @staticmethod
    def from_msg(payload) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]:
        node_id = payload['node_id']
        features = int.from_bytes(payload['features'], "big")
        features = validate_features(features)
        addresses = NodeInfo.parse_addresses_field(payload['addresses'])
        peer_addrs = []
        for host, port in addresses:
            try:
                peer_addrs.append(LNPeerAddr(host=host, port=port, pubkey=node_id))
            except ValueError:
                pass
        alias = payload['alias'].rstrip(b'\x00')
        try:
            alias = alias.decode('utf8')
        except Exception:
            alias = ''
        timestamp = payload['timestamp']
        node_info = NodeInfo(
            node_id=node_id,
            features=features,
            timestamp=timestamp,
            alias=alias,
            raw=payload.get('raw'))
        return node_info, peer_addrs

    @staticmethod
    def from_raw_msg(raw: bytes) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]:
        payload_dict = decode_msg(raw)[1]
        payload_dict['raw'] = raw
        return NodeInfo.from_msg(payload_dict)

    @staticmethod
    def to_addresses_field(hostname: str, port: int) -> bytes:
        """Encodes a hostname/port pair into a BOLT-7 'addresses' field."""
        if (NodeInfo.invalid_announcement_hostname(hostname)
                or port is None or port <= 0 or port > 65535):
            return b''
        port_bytes = port.to_bytes(2, 'big')
        if is_ip_address(hostname):  # ipv4 or ipv6
            ip_addr = ipaddress.ip_address(hostname)
            if ip_addr.version == 4:
                return b'\x01' + ip_addr.packed + port_bytes
            elif ip_addr.version == 6:
                return b'\x02' + ip_addr.packed + port_bytes
            return b''
        elif hostname.endswith('.onion'):  # Tor onion v3
            onion_addr: bytes = base64.b32decode(hostname[:-6], casefold=True)
            return b'\x04' + onion_addr + port_bytes
        else:
            try:
                hostname_ascii: bytes = hostname.encode('ascii')
            except UnicodeEncodeError:
                # Encoding single characters to punycode (as the spec suggests) doesn't
                # make sense, as they couldn't be differentiated from regular ascii.
                # Encoding the whole string to punycode doesn't work either, as the
                # receiver would interpret it as regular ascii.
                # hostname_ascii: bytes = hostname.encode('punycode')
                return b''
            if len(hostname_ascii) + 3 > 258:  # + 1 byte for length and 2 for port
                return b''  # too long
            return b'\x05' + len(hostname_ascii).to_bytes(1, "big") + hostname_ascii + port_bytes

    @staticmethod
    def invalid_announcement_hostname(hostname: Optional[str]) -> bool:
        """Returns True if the hostname is unsuited for publishing in a NodeAnnouncement."""
        if (hostname is None or hostname == ""
                or is_private_netaddress(hostname)
                or hostname.startswith("http://")  # not catching 'http' due to onion addresses
                or hostname.startswith("https://")):
            return True
        if hostname.endswith('.onion'):
            if len(hostname) != 62:  # not an onion v3 link (probably onion v2)
                return True
        return False

    @staticmethod
    def parse_addresses_field(addresses_field):
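        """Parses a BOLT-7 'addresses' field into a list of (host, port) pairs.

        Descriptor types: 1=IPv4, 2=IPv6, 3=onion v2 (skipped, deprecated),
        4=onion v3, 5=DNS hostname. Parsing stops at the first unknown type,
        as its length is unknown.
        """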
        buf = addresses_field

        def read(n):
            nonlocal buf
            data, buf = buf[0:n], buf[n:]
            return data

        addresses = []
        while buf:
            atype = ord(read(1))
            if atype == 0:
                pass
            elif atype == 1:  # IPv4
                ipv4_addr = '.'.join(map(lambda x: '%d' % x, read(4)))
                port = int.from_bytes(read(2), 'big')
                if is_ip_address(ipv4_addr) and port != 0:
                    addresses.append((ipv4_addr, port))
            elif atype == 2:  # IPv6
                ipv6_addr = b':'.join([binascii.hexlify(read(2)) for i in range(8)])
                ipv6_addr = ipv6_addr.decode('ascii')
                port = int.from_bytes(read(2), 'big')
                if is_ip_address(ipv6_addr) and port != 0:
                    addresses.append((ipv6_addr, port))
            elif atype == 3:  # onion v2
                read(12)  # we skip onion v2 as it is deprecated
            elif atype == 4:  # onion v3
                host = base64.b32encode(read(35)) + b'.onion'
                host = host.decode('ascii').lower()
                port = int.from_bytes(read(2), 'big')
                addresses.append((host, port))
            elif atype == 5:  # dns hostname
                len_hostname = int.from_bytes(read(1), 'big')
                host = read(len_hostname).decode('ascii')
                port = int.from_bytes(read(2), 'big')
                if not NodeInfo.invalid_announcement_hostname(host) and port > 0:
                    addresses.append((host, port))
            else:
                # unknown address type: we don't know its length, so we have to stop.
                # If other addresses we could have parsed follow, they are lost.
                break
        return addresses


class UpdateStatus(IntEnum):
    ORPHANED = 0
    EXPIRED = 1
    DEPRECATED = 2
    UNCHANGED = 3
    GOOD = 4


class CategorizedChannelUpdates(NamedTuple):
    orphaned: List    # no channel announcement for channel update
    expired: List     # update older than two weeks
    deprecated: List  # update older than database entry
    unchanged: List   # unchanged policies
    good: List        # good updates


def get_mychannel_info(short_channel_id: ShortChannelID,
                       my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[ChannelInfo]:
    chan = my_channels.get(short_channel_id)
    if not chan:
        return
    raw_msg, _ = chan.construct_channel_announcement_without_sigs()
    ci = ChannelInfo.from_raw_msg(raw_msg)
    return ci._replace(capacity_sat=chan.constraints.capacity)


def get_mychannel_policy(short_channel_id: bytes, node_id: bytes,
                         my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[Policy]:
    chan = my_channels.get(short_channel_id)  # type: Optional[Channel]
    if not chan:
        return
    if node_id == chan.node_id:  # incoming direction (to us)
        remote_update_raw = chan.get_remote_update()
        if not remote_update_raw:
            return
        now = int(time.time())
        remote_update_decoded = decode_msg(remote_update_raw)[1]
        remote_update_decoded['timestamp'] = now
        remote_update_decoded['start_node'] = node_id
        return Policy.from_msg(remote_update_decoded)
    elif node_id == chan.get_local_pubkey():  # outgoing direction (from us)
        local_update_decoded = decode_msg(chan.get_outgoing_gossip_channel_update())[1]
        local_update_decoded['start_node'] = node_id
        return Policy.from_msg(local_update_decoded)


class _LoadDataAborted(Exception): pass


create_channel_info = """
CREATE TABLE IF NOT EXISTS channel_info (
    short_channel_id BLOB(8),
    msg BLOB,
    PRIMARY KEY(short_channel_id)
)"""

create_policy = """
CREATE TABLE IF NOT EXISTS policy (
    key BLOB(41),
    msg BLOB,
    PRIMARY KEY(key)
)"""

create_address = """
CREATE TABLE IF NOT EXISTS address (
    node_id BLOB(33),
    host STRING(256),
    port INTEGER NOT NULL,
    timestamp INTEGER,
    PRIMARY KEY(node_id, host, port)
)"""

create_node_info = """
CREATE TABLE IF NOT EXISTS node_info (
    node_id BLOB(33),
    msg BLOB,
    PRIMARY KEY(node_id)
)"""


class ChannelDB(SqlDB):

    NUM_MAX_RECENT_PEERS = 20
    PRIVATE_CHAN_UPD_CACHE_TTL_NORMAL = 600
    PRIVATE_CHAN_UPD_CACHE_TTL_SHORT = 120

    def __init__(self, network: 'Network'):
        path = self.get_file_path(network.config)
        super().__init__(network.asyncio_loop, path, commit_interval=100)
        self.lock = threading.RLock()
        self.num_nodes = 0
        self.num_channels = 0
        self.num_policies = 0
        self._channel_updates_for_private_channels = {}  # type: Dict[Tuple[bytes, bytes], Tuple[dict, int]]
        # note: ^ we could maybe move this cache into PaySession instead of being global.
        #       That would only make sense though if PaySessions were never too short
        #       (e.g. consider trampoline forwarding).
        self.ca_verifier = LNChannelVerifier(network, self)

        # initialized in load_data
        # note: modify/iterate needs self.lock
        self._channels = {}  # type: Dict[ShortChannelID, ChannelInfo]
        self._policies = {}  # type: Dict[Tuple[bytes, ShortChannelID], Policy]  # (node_id, scid) -> Policy
        self._nodes = {}  # type: Dict[bytes, NodeInfo]  # node_id -> NodeInfo
        # node_id -> NetAddress -> timestamp
        self._addresses = defaultdict(dict)  # type: Dict[bytes, Dict[NetAddress, int]]
        self._channels_for_node = defaultdict(set)  # type: Dict[bytes, Set[ShortChannelID]]
        self._recent_peers = []  # type: List[bytes]  # list of node_ids
        self._chans_with_0_policies = set()  # type: Set[ShortChannelID]
        self._chans_with_1_policies = set()  # type: Set[ShortChannelID]
        self._chans_with_2_policies = set()  # type: Set[ShortChannelID]

        self.forwarding_lock = threading.RLock()
        self.fwd_channels = []  # type: List[GossipForwardingMessage]
        self.fwd_orphan_channels = []  # type: List[GossipForwardingMessage]
        self.fwd_channel_updates = []  # type: List[GossipForwardingMessage]
        self.fwd_node_announcements = []  # type: List[GossipForwardingMessage]

        self.data_loaded = asyncio.Event()
        self.network = network  # only for callback

    @classmethod
    def get_file_path(cls, config: 'SimpleConfig') -> str:
        return os.path.join(get_headers_dir(config), 'gossip_db')

    def update_counts(self):
        self.num_nodes = len(self._nodes)
        self.num_channels = len(self._channels)
        self.num_policies = len(self._policies)
        util.trigger_callback('channel_db', self.num_nodes, self.num_channels, self.num_policies)
        util.trigger_callback('ln_gossip_sync_progress')

    def get_channel_ids(self):
        with self.lock:
            return set(self._channels.keys())

    def add_recent_peer(self, peer: LNPeerAddr):
        now = int(time.time())
        node_id = peer.pubkey
        with self.lock:
            self._addresses[node_id][peer.net_addr()] = now
            # list is ordered
            if node_id in self._recent_peers:
                self._recent_peers.remove(node_id)
            self._recent_peers.insert(0, node_id)
            self._recent_peers = self._recent_peers[:self.NUM_MAX_RECENT_PEERS]
        self._db_save_node_address(peer, now)

    def get_200_randomly_sorted_nodes_not_in(self, node_ids):
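        """Returns up to 200 random node ids that are not in `node_ids`."""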
        with self.lock:
            unshuffled = set(self._nodes.keys()) - node_ids
        return random.sample(list(unshuffled), min(200, len(unshuffled)))

    def get_last_good_address(self, node_id: bytes) -> Optional[LNPeerAddr]:
        """Returns the latest address we successfully connected to, for the given node."""
        addr_to_ts = self._addresses.get(node_id)
        if not addr_to_ts:
            return None
        addr = sorted(list(addr_to_ts), key=lambda a: addr_to_ts[a], reverse=True)[0]
        try:
            return LNPeerAddr(str(addr.host), addr.port, node_id)
        except ValueError:
            return None

    def get_recent_peers(self):
        if not self.data_loaded.is_set():
            raise ChannelDBNotLoaded("channelDB data not loaded yet!")
        with self.lock:
            ret = [self.get_last_good_address(node_id)
                   for node_id in self._recent_peers]
        return ret

    # note: currently channel announcements are trusted by default (trusted=True);
    #       they are not SPV-verified. Verifying them would make the gossip sync
    #       even slower; especially as servers will start throttling us.
    #       It would probably put significant strain on servers if all clients
    #       verified the complete gossip.
    def add_channel_announcements(self, msg_payloads, *, trusted=True):
        # note: signatures have already been verified.
        if type(msg_payloads) is dict:
            msg_payloads = [msg_payloads]
        added = 0
        for msg in msg_payloads:
            short_channel_id = ShortChannelID(msg['short_channel_id'])
            if short_channel_id in self._channels:
                continue
            if constants.net.rev_genesis_bytes() != msg['chain_hash']:
                self.logger.info("ChanAnn has unexpected chain_hash {}".format(msg['chain_hash'].hex()))
                continue
            try:
                channel_info = ChannelInfo.from_msg(msg)
            except IncompatibleOrInsaneFeatures as e:
                self.logger.info(f"unknown or insane feature bits: {e!r}")
                continue
            if trusted:
                added += 1
                self.add_verified_channel_info(msg)
            else:
                added += self.ca_verifier.add_new_channel_info(short_channel_id, msg)

        self.update_counts()

    def add_verified_channel_info(self, msg: dict, *, capacity_sat: int = None) -> None:
        try:
            channel_info = ChannelInfo.from_msg(msg)
        except IncompatibleOrInsaneFeatures:
            return
        channel_info = channel_info._replace(capacity_sat=capacity_sat)
        with self.lock:
            self._channels[channel_info.short_channel_id] = channel_info
            self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
            self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
        self._update_num_policies_for_chan(channel_info.short_channel_id)
        if 'raw' in msg:
            self._db_save_channel(channel_info.short_channel_id, msg['raw'])
        with self.forwarding_lock:
            if fwd_msg := GossipForwardingMessage.from_payload(msg):
                self.fwd_channels.append(fwd_msg)

    def policy_changed(self, old_policy: Policy, new_policy: Policy, verbose: bool) -> bool:
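        """Returns True if any field relevant for routing differs between the two policies."""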
        changed = False
        if old_policy.cltv_delta != new_policy.cltv_delta:
            changed |= True
            if verbose:
                self.logger.info(f'cltv_expiry_delta: {old_policy.cltv_delta} -> {new_policy.cltv_delta}')
        if old_policy.htlc_minimum_msat != new_policy.htlc_minimum_msat:
            changed |= True
            if verbose:
                self.logger.info(f'htlc_minimum_msat: {old_policy.htlc_minimum_msat} -> {new_policy.htlc_minimum_msat}')
        if old_policy.htlc_maximum_msat != new_policy.htlc_maximum_msat:
            changed |= True
            if verbose:
                self.logger.info(f'htlc_maximum_msat: {old_policy.htlc_maximum_msat} -> {new_policy.htlc_maximum_msat}')
        if old_policy.fee_base_msat != new_policy.fee_base_msat:
            changed |= True
            if verbose:
                self.logger.info(f'fee_base_msat: {old_policy.fee_base_msat} -> {new_policy.fee_base_msat}')
        if old_policy.fee_proportional_millionths != new_policy.fee_proportional_millionths:
            changed |= True
            if verbose:
                self.logger.info(f'fee_proportional_millionths: {old_policy.fee_proportional_millionths} -> {new_policy.fee_proportional_millionths}')
        if old_policy.channel_flags != new_policy.channel_flags:
            changed |= True
            if verbose:
                self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}')
        if old_policy.message_flags != new_policy.message_flags:
            changed |= True
            if verbose:
                self.logger.info(f'message_flags: {old_policy.message_flags} -> {new_policy.message_flags}')
        if not changed and verbose:
            self.logger.info(f'policy unchanged: {old_policy.timestamp} -> {new_policy.timestamp}')
        return changed

    def add_channel_update(
            self, payload, *, max_age=None, verify=True, verbose=True) -> UpdateStatus:
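        """Validates a single channel_update and stores its policy; returns an UpdateStatus.

        EXPIRED: older than max_age. ORPHANED: no channel_announcement known.
        DEPRECATED: not newer than the stored policy, or from the future.
        Otherwise the policy is stored, and the result is UNCHANGED or GOOD
        depending on whether any routing-relevant field actually changed.
        """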
        now = int(time.time())
        short_channel_id = ShortChannelID(payload['short_channel_id'])
        timestamp = payload['timestamp']
        if max_age and now - timestamp > max_age:
            return UpdateStatus.EXPIRED
        if timestamp - now > 60:
            return UpdateStatus.DEPRECATED
        channel_info = self._channels.get(short_channel_id)
        if not channel_info:
            return UpdateStatus.ORPHANED
        flags = int.from_bytes(payload['channel_flags'], 'big')
        direction = flags & FLAG_DIRECTION
        start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
        payload['start_node'] = start_node
        # compare updates to existing database entries
        short_channel_id = ShortChannelID(payload['short_channel_id'])
        key = (start_node, short_channel_id)
        old_policy = self._policies.get(key)
        if old_policy and timestamp <= old_policy.timestamp + 60:
            return UpdateStatus.DEPRECATED
        if verify:
            self.verify_channel_update(payload)
        policy = Policy.from_msg(payload)
        with self.lock:
            self._policies[key] = policy
        self._update_num_policies_for_chan(short_channel_id)
        if 'raw' in payload:
            self._db_save_policy(policy.key, payload['raw'])
        if old_policy and not self.policy_changed(old_policy, policy, verbose):
            return UpdateStatus.UNCHANGED
        else:
            if policy.message_flags & 0b10 == 0:  # check that the `dont_forward` bit is not set
                with self.forwarding_lock:
                    if fwd_msg := GossipForwardingMessage.from_payload(payload):
                        self.fwd_channel_updates.append(fwd_msg)
            return UpdateStatus.GOOD

    def add_channel_updates(self, payloads, max_age=None) -> CategorizedChannelUpdates:
        orphaned = []
        expired = []
        deprecated = []
        unchanged = []
        good = []
        for payload in payloads:
            r = self.add_channel_update(payload, max_age=max_age, verbose=False, verify=True)
            if r == UpdateStatus.ORPHANED:
                orphaned.append(payload)
            elif r == UpdateStatus.EXPIRED:
                expired.append(payload)
            elif r == UpdateStatus.DEPRECATED:
                deprecated.append(payload)
            elif r == UpdateStatus.UNCHANGED:
                unchanged.append(payload)
            elif r == UpdateStatus.GOOD:
                good.append(payload)
        self.update_counts()
        return CategorizedChannelUpdates(
            orphaned=orphaned,
            expired=expired,
            deprecated=deprecated,
            unchanged=unchanged,
            good=good)

    def create_database(self):
        c = self.conn.cursor()
        c.execute(create_node_info)
        c.execute(create_address)
        c.execute(create_policy)
        c.execute(create_channel_info)
        self.conn.commit()

    @sql
    def _db_save_policy(self, key: bytes, msg: bytes):
        # 'msg' is a 'channel_update' message
        c = self.conn.cursor()
        c.execute("""REPLACE INTO policy (key, msg) VALUES (?,?)""", [key, msg])

    @sql
    def _db_delete_policy(self, node_id: bytes, short_channel_id: ShortChannelID):
        key = short_channel_id + node_id
        c = self.conn.cursor()
        c.execute("""DELETE FROM policy WHERE key=?""", (key,))

    @sql
    def _db_save_channel(self, short_channel_id: ShortChannelID, msg: bytes):
        # 'msg' is a 'channel_announcement' message
        c = self.conn.cursor()
        c.execute("REPLACE INTO channel_info (short_channel_id, msg) VALUES (?,?)", [short_channel_id, msg])

    @sql
    def _db_delete_channel(self, short_channel_id: ShortChannelID):
        c = self.conn.cursor()
        c.execute("""DELETE FROM channel_info WHERE short_channel_id=?""", (short_channel_id,))

    @sql
    def _db_save_node_info(self, node_id: bytes, msg: bytes):
        # 'msg' is a 'node_announcement' message
        c = self.conn.cursor()
        c.execute("REPLACE INTO node_info (node_id, msg) VALUES (?,?)", [node_id, msg])

    @sql
    def _db_save_node_address(self, peer: LNPeerAddr, timestamp: int):
        c = self.conn.cursor()
        c.execute("REPLACE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)",
                  (peer.pubkey, peer.host, peer.port, timestamp))

    @sql
    def _db_save_node_addresses(self, node_addresses: Sequence[LNPeerAddr]):
        c = self.conn.cursor()
        for addr in node_addresses:
            c.execute("SELECT * FROM address WHERE node_id=? AND host=? AND port=?", (addr.pubkey, addr.host, addr.port))
            r = c.fetchall()
            if r == []:
                c.execute("INSERT INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (addr.pubkey, addr.host, addr.port, 0))

    @classmethod
    def verify_channel_update(cls, payload, *, start_node: bytes = None) -> None:
        short_channel_id = payload['short_channel_id']
        short_channel_id = ShortChannelID(short_channel_id)
        if constants.net.rev_genesis_bytes() != payload['chain_hash']:
            raise InvalidGossipMsg('wrong chain hash')
        start_node = payload.get('start_node', None) or start_node
        assert start_node is not None
        if not verify_sig_for_channel_update(payload, start_node):
            raise InvalidGossipMsg(f'failed verifying channel update for {short_channel_id}')

    @classmethod
    def verify_channel_announcement(cls, payload) -> None:
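        # the announcement is signed over everything that follows the 2-byte
        # message type and the four 64-byte signatures (hence the 2+256 offset)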
        h = sha256d(payload['raw'][2+256:])
        pubkeys = [payload['node_id_1'], payload['node_id_2'], payload['bitcoin_key_1'], payload['bitcoin_key_2']]
        sigs = [payload['node_signature_1'], payload['node_signature_2'], payload['bitcoin_signature_1'], payload['bitcoin_signature_2']]
        for pubkey, sig in zip(pubkeys, sigs):
            if not ECPubkey(pubkey).ecdsa_verify(sig, h):
                raise InvalidGossipMsg('signature failed')

    @classmethod
    def verify_node_announcement(cls, payload) -> None:
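        # signed over everything after the 2-byte message type and the
        # 64-byte signature (hence the 66-byte offset)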
        pubkey = payload['node_id']
        signature = payload['signature']
        h = sha256d(payload['raw'][66:])
        if not ECPubkey(pubkey).ecdsa_verify(signature, h):
            raise InvalidGossipMsg('signature failed')

    def add_node_announcements(self, msg_payloads):
        # note: signatures have already been verified.
        if type(msg_payloads) is dict:
            msg_payloads = [msg_payloads]
        new_nodes = set()  # type: Set[bytes]
        for msg_payload in msg_payloads:
            try:
                node_info, node_addresses = NodeInfo.from_msg(msg_payload)
            except IncompatibleOrInsaneFeatures:
                continue
            node_id = node_info.node_id
            # Ignore the node if it has no associated channel (DoS protection)
            if node_id not in self._channels_for_node:
                # self.logger.info('ignoring orphan node_announcement')
                continue
            node = self._nodes.get(node_id)
            if node and node.timestamp >= node_info.timestamp:
                continue
            new_nodes.add(node_id)
            # save
            with self.lock:
                self._nodes[node_id] = node_info
            if 'raw' in msg_payload:
                self._db_save_node_info(node_id, msg_payload['raw'])
            with self.lock:
                for addr in node_addresses:
                    net_addr = NetAddress(addr.host, addr.port)
                    self._addresses[node_id][net_addr] = self._addresses[node_id].get(net_addr) or 0
            self._db_save_node_addresses(node_addresses)
            with self.forwarding_lock:
                if fwd_msg := GossipForwardingMessage.from_payload(msg_payload):
                    self.fwd_node_announcements.append(fwd_msg)

        self.update_counts()

    def get_old_policies(self, delta) -> Sequence[Tuple[bytes, ShortChannelID]]:
        with self.lock:
            _policies = self._policies.copy()
        now = int(time.time())
        return list(k for k, v in _policies.items() if v.timestamp <= now - delta)

    def prune_old_policies(self, delta):
        old_policies = self.get_old_policies(delta)
        if old_policies:
            for key in old_policies:
                node_id, scid = key
                with self.lock:
                    self._policies.pop(key)
                self._db_delete_policy(*key)
                self._update_num_policies_for_chan(scid)
            self.update_counts()
            self.logger.info(f'Deleting {len(old_policies)} old policies')

    def prune_orphaned_channels(self):
        with self.lock:
            orphaned_chans = self._chans_with_0_policies.copy()
        if orphaned_chans:
            for short_channel_id in orphaned_chans:
                self.remove_channel(short_channel_id)
            self.update_counts()
            self.logger.info(f'Deleting {len(orphaned_chans)} orphaned channels')

    def _get_channel_update_for_private_channel(
            self,
            start_node_id: bytes,
            short_channel_id: ShortChannelID,
            *,
            now: int = None,  # unix ts
    ) -> Optional[dict]:
        if now is None:
            now = int(time.time())
        key = (start_node_id, short_channel_id)
        chan_upd_dict, cache_expiration = self._channel_updates_for_private_channels.get(key, (None, 0))
        if cache_expiration < now:
            chan_upd_dict = None  # already expired
            # TODO rm expired entries from cache (note: perf vs thread-safety)
        return chan_upd_dict

    def add_channel_update_for_private_channel(
            self,
            msg_payload: dict,
            start_node_id: bytes,
            *,
            cache_ttl: int = None,  # seconds
    ) -> bool:
        """Returns True iff the channel update was successfully added and it was
        different from what we had before (if any).
        """
        if not verify_sig_for_channel_update(msg_payload, start_node_id):
            return False  # ignore
        now = int(time.time())
        short_channel_id = ShortChannelID(msg_payload['short_channel_id'])
        msg_payload['start_node'] = start_node_id
        prev_chanupd = self._get_channel_update_for_private_channel(start_node_id, short_channel_id, now=now)
        if prev_chanupd == msg_payload:
            return False
        if cache_ttl is None:
            cache_ttl = self.PRIVATE_CHAN_UPD_CACHE_TTL_NORMAL
        cache_expiration = now + cache_ttl
        key = (start_node_id, short_channel_id)
        with self.lock:
            self._channel_updates_for_private_channels[key] = msg_payload, cache_expiration
        return True

    def remove_channel(self, short_channel_id: ShortChannelID):
        # FIXME what about rm-ing policies?
        with self.lock:
            channel_info = self._channels.pop(short_channel_id, None)
            if channel_info:
                self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id)
                self._channels_for_node[channel_info.node2_id].remove(channel_info.short_channel_id)
        self._update_num_policies_for_chan(short_channel_id)
        # delete from database
        self._db_delete_channel(short_channel_id)

    def get_node_addresses(self, node_id: bytes) -> Sequence[Tuple[str, int, int]]:
        """Returns a list of (host, port, timestamp)."""
        addr_to_ts = self._addresses.get(node_id)
        if not addr_to_ts:
            return []
        return [(str(net_addr.host), net_addr.port, ts)
                for net_addr, ts in addr_to_ts.items()]

    def handle_abort(func):
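        """Decorator that catches _LoadDataAborted and returns None instead of raising."""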
        @functools.wraps(func)
        def wrapper(self: 'ChannelDB', *args, **kwargs):
            try:
                return func(self, *args, **kwargs)
            except _LoadDataAborted:
                return
        return wrapper

    @sql
    @profiler
    @handle_abort
    def load_data(self):
        if self.data_loaded.is_set():
            return

        # Note: this method takes several seconds... mostly due to lnmsg.decode_msg being slow.
        def maybe_abort():
            if self.stopping:
                self.logger.info("load_data() was asked to stop. exiting early.")
                raise _LoadDataAborted()
        c = self.conn.cursor()
        c.execute("""SELECT * FROM address""")
        for x in c:
            maybe_abort()
            node_id, host, port, timestamp = x
            try:
                net_addr = NetAddress(host, port)
            except Exception:
                continue
            self._addresses[node_id][net_addr] = int(timestamp or 0)

        def newest_ts_for_node_id(node_id):
            newest_ts = 0
            for addr, ts in self._addresses[node_id].items():
                newest_ts = max(newest_ts, ts)
            return newest_ts
        sorted_node_ids = sorted(self._addresses.keys(), key=newest_ts_for_node_id, reverse=True)
        self._recent_peers = sorted_node_ids[:self.NUM_MAX_RECENT_PEERS]
        c.execute("""SELECT * FROM channel_info""")
        for short_channel_id, msg in c:
            maybe_abort()
            try:
                ci = ChannelInfo.from_raw_msg(msg)
            except IncompatibleOrInsaneFeatures:
                continue
            except FailedToParseMsg:
                continue
            self._channels[ShortChannelID.normalize(short_channel_id)] = ci
        c.execute("""SELECT * FROM node_info""")
        for node_id, msg in c:
            maybe_abort()
            try:
                node_info, node_addresses = NodeInfo.from_raw_msg(msg)
            except IncompatibleOrInsaneFeatures:
                continue
            except FailedToParseMsg:
                continue
            # don't load node_addresses because they don't have timestamps
            self._nodes[node_id] = node_info
        c.execute("""SELECT * FROM policy""")
        for key, msg in c:
            maybe_abort()
            try:
                p = Policy.from_raw_msg(key, msg)
            except FailedToParseMsg:
                continue
            self._policies[(p.start_node, p.short_channel_id)] = p
        for channel_info in self._channels.values():
            self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
            self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
            self._update_num_policies_for_chan(channel_info.short_channel_id)
        self.logger.info(f'data loaded. {len(self._channels)} chans. {len(self._policies)} policies. '
                         f'{len(self._channels_for_node)} nodes.')
        self.update_counts()
        (nchans_with_0p, nchans_with_1p, nchans_with_2p) = self.get_num_channels_partitioned_by_policy_count()
        self.logger.info(f'num_channels_partitioned_by_policy_count. '
                         f'0p: {nchans_with_0p}, 1p: {nchans_with_1p}, 2p: {nchans_with_2p}')
        self.asyncio_loop.call_soon_threadsafe(self.data_loaded.set)
        util.trigger_callback('gossip_db_loaded')

    def _update_num_policies_for_chan(self, short_channel_id: ShortChannelID) -> None:
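        """Re-buckets the channel by how many of its two policies are known (0, 1 or 2)."""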
        channel_info = self.get_channel_info(short_channel_id)
        if channel_info is None:
            with self.lock:
                self._chans_with_0_policies.discard(short_channel_id)
                self._chans_with_1_policies.discard(short_channel_id)
                self._chans_with_2_policies.discard(short_channel_id)
            return
        p1 = self.get_policy_for_node(short_channel_id, channel_info.node1_id)
        p2 = self.get_policy_for_node(short_channel_id, channel_info.node2_id)
        with self.lock:
            self._chans_with_0_policies.discard(short_channel_id)
            self._chans_with_1_policies.discard(short_channel_id)
            self._chans_with_2_policies.discard(short_channel_id)
            if p1 is not None and p2 is not None:
                self._chans_with_2_policies.add(short_channel_id)
            elif p1 is None and p2 is None:
                self._chans_with_0_policies.add(short_channel_id)
            else:
                self._chans_with_1_policies.add(short_channel_id)

    def get_num_channels_partitioned_by_policy_count(self) -> Tuple[int, int, int]:
        nchans_with_0p = len(self._chans_with_0_policies)
        nchans_with_1p = len(self._chans_with_1_policies)
        nchans_with_2p = len(self._chans_with_2_policies)
        return nchans_with_0p, nchans_with_1p, nchans_with_2p

    def get_policy_for_node(
            self,
            short_channel_id: ShortChannelID,
            node_id: bytes,
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
            now: int = None,  # unix ts
    ) -> Optional['Policy']:
        channel_info = self.get_channel_info(short_channel_id)
        if channel_info is not None:  # publicly announced channel
            policy = self._policies.get((node_id, short_channel_id))
            if policy:
                return policy
        elif chan_upd_dict := self._get_channel_update_for_private_channel(node_id, short_channel_id, now=now):
            return Policy.from_msg(chan_upd_dict)
        # check if it's one of our own channels
        if my_channels:
            policy = get_mychannel_policy(short_channel_id, node_id, my_channels)
            if policy:
                return policy
        if private_route_edges:
            route_edge = private_route_edges.get(short_channel_id, None)
            if route_edge:
                return Policy.from_route_edge(route_edge)

    def get_channel_info(
            self,
            short_channel_id: ShortChannelID,
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
    ) -> Optional[ChannelInfo]:
        ret = self._channels.get(short_channel_id)
        if ret:
            return ret
        # check if it's one of our own channels
        if my_channels:
            channel_info = get_mychannel_info(short_channel_id, my_channels)
            if channel_info:
                return channel_info
        if private_route_edges:
            route_edge = private_route_edges.get(short_channel_id)
            if route_edge:
                return ChannelInfo.from_route_edge(route_edge)

    def get_channels_for_node(
            self,
            node_id: bytes,
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
    ) -> Set[ShortChannelID]:
        """Returns the set of short channel IDs where node_id is one of the channel participants."""
        if not self.data_loaded.is_set():
            raise ChannelDBNotLoaded("channelDB data not loaded yet!")
        relevant_channels = self._channels_for_node.get(node_id) or set()
        relevant_channels = set(relevant_channels)  # copy
        # add our own channels  # TODO maybe slow?
        if my_channels:
            for chan in my_channels.values():
                if node_id in (chan.node_id, chan.get_local_pubkey()):
                    relevant_channels.add(chan.short_channel_id)
        # add private channels  # TODO maybe slow?
        if private_route_edges:
            for route_edge in private_route_edges.values():
                if node_id in (route_edge.start_node, route_edge.end_node):
                    relevant_channels.add(route_edge.short_channel_id)
        return relevant_channels

    def get_endnodes_for_chan(self, short_channel_id: ShortChannelID, *,
                              my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[Tuple[bytes, bytes]]:
        channel_info = self.get_channel_info(short_channel_id)
        if channel_info is not None:  # publicly announced channel
            return channel_info.node1_id, channel_info.node2_id
        # check if it's one of our own channels
        if not my_channels:
            return
        chan = my_channels.get(short_channel_id)  # type: Optional[Channel]
        if not chan:
            return
        return chan.get_local_pubkey(), chan.node_id

    def get_node_info_for_node_id(self, node_id: bytes) -> Optional['NodeInfo']:
        return self._nodes.get(node_id)

    def get_node_infos(self) -> Dict[bytes, NodeInfo]:
        with self.lock:
            return self._nodes.copy()

    def get_node_policies(self) -> Dict[Tuple[bytes, ShortChannelID], Policy]:
        with self.lock:
            return self._policies.copy()

    def get_node_by_prefix(self, prefix):
        with self.lock:
            for k in self._addresses.keys():
                if k.startswith(prefix):
                    return k
        raise Exception('node not found')

    def clear_forwarding_gossip(self) -> None:
        with self.forwarding_lock:
            self.fwd_channels.clear()
            self.fwd_channel_updates.clear()
            self.fwd_node_announcements.clear()

    def filter_orphan_channel_anns(
        self, channel_anns: List[GossipForwardingMessage]
    ) -> Tuple[List, List]:
        """Checks whether the channel announcements we want to forward have at least one update."""
        to_forward_anns = []
        orphaned_channel_anns = []
        for channel in channel_anns:
            if channel.scid is None:
                continue
            elif (channel.scid in self._chans_with_1_policies
                    or channel.scid in self._chans_with_2_policies):
                to_forward_anns.append(channel)
                continue
            orphaned_channel_anns.append(channel)
        return to_forward_anns, orphaned_channel_anns

    def set_fwd_channel_anns_ts(self, channel_anns: List[GossipForwardingMessage]) \
            -> List[GossipForwardingMessage]:
        """Sets the timestamps of the passed channel announcements from the corresponding policies."""
        timestamped_chan_anns: List[GossipForwardingMessage] = []
        with self.lock:
            policies = self._policies.copy()
            channels = self._channels.copy()

        for chan_ann in channel_anns:
            if chan_ann.timestamp is not None:
                timestamped_chan_anns.append(chan_ann)
                continue

            scid = chan_ann.scid
            if (channel_info := channels.get(scid)) is None:
                continue

            policy1 = policies.get((channel_info.node1_id, scid))
            policy2 = policies.get((channel_info.node2_id, scid))
            potential_timestamps = []
            for policy in [policy1, policy2]:
                if policy is not None:
                    potential_timestamps.append(policy.timestamp)
            if not potential_timestamps:
                continue
            chan_ann.timestamp = min(potential_timestamps)
            timestamped_chan_anns.append(chan_ann)
        return timestamped_chan_anns

    def get_forwarding_gossip_batch(self) -> List[GossipForwardingMessage]:
        with self.forwarding_lock:
            fwd_gossip = self.fwd_channel_updates + self.fwd_node_announcements
            channel_anns = self.fwd_channels.copy()
            self.clear_forwarding_gossip()

        fwd_chan_anns1, _ = self.filter_orphan_channel_anns(self.fwd_orphan_channels)
        fwd_chan_anns2, self.fwd_orphan_channels = self.filter_orphan_channel_anns(channel_anns)
        channel_anns = self.set_fwd_channel_anns_ts(fwd_chan_anns1 + fwd_chan_anns2)
        return channel_anns + fwd_gossip

    def get_gossip_in_timespan(self, timespan: GossipTimestampFilter) \
            -> List[GossipForwardingMessage]:
        """Returns a list of gossip messages matching the requested timespan."""
        forwarding_gossip = []
        with self.lock:
            chans = self._channels.copy()
            policies = self._policies.copy()
            nodes = self._nodes.copy()

        for short_id, chan in chans.items():
            # fetching the timestamp from the channel update (according to BOLT-07)
            chan_up_n1 = policies.get((chan.node1_id, short_id))
            chan_up_n2 = policies.get((chan.node2_id, short_id))
            updates = []
            for policy in [chan_up_n1, chan_up_n2]:
                if policy and policy.raw and timespan.in_range(policy.timestamp):
                    if policy.message_flags & 0b10 == 0:  # check that `dont_forward` is not set
                        updates.append(GossipForwardingMessage(
                            msg=policy.raw,
                            timestamp=policy.timestamp))
            if not updates or chan.raw is None:
                continue
            chan_ann_ts = min(update.timestamp for update in updates)
            channel_announcement = GossipForwardingMessage(msg=chan.raw, timestamp=chan_ann_ts)
            forwarding_gossip.extend([channel_announcement] + updates)

        for node_ann in nodes.values():
            if timespan.in_range(node_ann.timestamp) and node_ann.raw:
                forwarding_gossip.append(GossipForwardingMessage(
                    msg=node_ann.raw,
                    timestamp=node_ann.timestamp))
        return forwarding_gossip

    def get_channels_in_range(self, first_blocknum: int, number_of_blocks: int) -> List[ShortChannelID]:
        with self.lock:
            channels = self._channels.copy()
        scids: List[ShortChannelID] = []
        for scid in channels:
            if first_blocknum <= scid.block_height < first_blocknum + number_of_blocks:
                scids.append(scid)
        scids.sort()
        return scids

    def get_gossip_for_scid_request(self, scid: ShortChannelID) -> List[bytes]:
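        """Returns the raw gossip messages (channel_announcement, channel_updates,
        node_announcements) known for the given scid.
        """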
        requested_gossip = []

        chan_ann = self._channels.get(scid)
        if not chan_ann or not chan_ann.raw:
            return []
        chan_up1 = self._policies.get((chan_ann.node1_id, scid))
        chan_up2 = self._policies.get((chan_ann.node2_id, scid))
        node_ann1 = self._nodes.get(chan_ann.node1_id)
        node_ann2 = self._nodes.get(chan_ann.node2_id)

        for msg in [chan_ann, chan_up1, chan_up2, node_ann1, node_ann2]:
            if msg and msg.raw:
                requested_gossip.append(msg.raw)
        return requested_gossip

    def to_dict(self) -> dict:
        """Generates a graph representation in terms of a dictionary.

        The dictionary contains only native python types and can be encoded
        to json.
        """
        with self.lock:
            graph = {'nodes': [], 'channels': []}

            # gather nodes
            for pk, nodeinfo in self._nodes.items():
                # use _asdict() to convert NamedTuples to json encodable dicts
                graph['nodes'].append(
                    nodeinfo._asdict(),
                )
                graph['nodes'][-1]['addresses'] = [
                    {'host': str(addr.host), 'port': addr.port, 'timestamp': ts}
                    for addr, ts in self._addresses[pk].items()
                ]

            # gather channels
            for cid, channelinfo in self._channels.items():
                graph['channels'].append(
                    channelinfo._asdict(),
                )
                policy1 = self._policies.get(
                    (channelinfo.node1_id, channelinfo.short_channel_id))
                policy2 = self._policies.get(
                    (channelinfo.node2_id, channelinfo.short_channel_id))
                graph['channels'][-1]['policy1'] = policy1._asdict() if policy1 else None
                graph['channels'][-1]['policy2'] = policy2._asdict() if policy2 else None

        # need to use json_normalize otherwise json encoding in rpc server fails
        graph = json_normalize(graph)
        return graph