1
0

sqlite in lnrouter

This commit is contained in:
Janus
2019-02-01 20:59:59 +01:00
committed by ThomasV
parent d94e40d2be
commit dd7c4b3bab
8 changed files with 398 additions and 396 deletions

View File

@@ -23,6 +23,8 @@
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import datetime
import random
import queue
import os
import json
@@ -33,6 +35,14 @@ import binascii
import base64
import asyncio
from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
from sqlalchemy.orm import scoped_session
from . import constants
from .util import PrintError, bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
from .storage import JsonDB
@@ -41,112 +51,113 @@ from .crypto import sha256d
from . import ecc
from .lnutil import (LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH,
NotFoundChanAnnouncementForUpdate)
from .lnmsg import encode_msg
if TYPE_CHECKING:
from .lnchannel import Channel
from .network import Network
class UnknownEvenFeatureBits(Exception): pass
class NoChannelPolicy(Exception):
def __init__(self, short_channel_id: bytes):
super().__init__(f'cannot find channel policy for short_channel_id: {bh2u(short_channel_id)}')
def validate_features(features : int):
enabled_features = list_enabled_bits(features)
for fbit in enabled_features:
if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0:
raise UnknownEvenFeatureBits()
Base = declarative_base()
session_factory = sessionmaker()
DBSession = scoped_session(session_factory)
engine = None
class ChannelInfo(PrintError):
FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0
def __init__(self, channel_announcement_payload):
self.features_len = channel_announcement_payload['len']
self.features = channel_announcement_payload['features']
enabled_features = list_enabled_bits(int.from_bytes(self.features, "big"))
for fbit in enabled_features:
if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0:
raise UnknownEvenFeatureBits()
class ChannelInfoInDB(Base):
__tablename__ = 'channel_info'
short_channel_id = Column(String(64), primary_key=True)
node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
capacity_sat = Column(Integer)
msg_payload_hex = Column(String(1024), nullable=False)
trusted = Column(Boolean, nullable=False)
self.channel_id = channel_announcement_payload['short_channel_id']
self.node_id_1 = channel_announcement_payload['node_id_1']
self.node_id_2 = channel_announcement_payload['node_id_2']
assert type(self.node_id_1) is bytes
assert type(self.node_id_2) is bytes
assert list(sorted([self.node_id_1, self.node_id_2])) == [self.node_id_1, self.node_id_2]
@staticmethod
def from_msg(channel_announcement_payload):
features = int.from_bytes(channel_announcement_payload['features'], 'big')
validate_features(features)
self.bitcoin_key_1 = channel_announcement_payload['bitcoin_key_1']
self.bitcoin_key_2 = channel_announcement_payload['bitcoin_key_2']
channel_id = channel_announcement_payload['short_channel_id'].hex()
node_id_1 = channel_announcement_payload['node_id_1'].hex()
node_id_2 = channel_announcement_payload['node_id_2'].hex()
assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
# this field does not get persisted
self.msg_payload = channel_announcement_payload
msg_payload_hex = encode_msg('channel_announcement', **channel_announcement_payload).hex()
self.capacity_sat = None
self.policy_node1 = None
self.policy_node2 = None
capacity_sat = None
def to_json(self) -> dict:
d = {}
d['short_channel_id'] = bh2u(self.channel_id)
d['node_id_1'] = bh2u(self.node_id_1)
d['node_id_2'] = bh2u(self.node_id_2)
d['len'] = bh2u(self.features_len)
d['features'] = bh2u(self.features)
d['bitcoin_key_1'] = bh2u(self.bitcoin_key_1)
d['bitcoin_key_2'] = bh2u(self.bitcoin_key_2)
d['policy_node1'] = self.policy_node1
d['policy_node2'] = self.policy_node2
d['capacity_sat'] = self.capacity_sat
return d
return ChannelInfoInDB(short_channel_id = channel_id, node1_id = node_id_1,
node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex,
trusted = False)
@classmethod
def from_json(cls, d: dict):
d2 = {}
d2['short_channel_id'] = bfh(d['short_channel_id'])
d2['node_id_1'] = bfh(d['node_id_1'])
d2['node_id_2'] = bfh(d['node_id_2'])
d2['len'] = bfh(d['len'])
d2['features'] = bfh(d['features'])
d2['bitcoin_key_1'] = bfh(d['bitcoin_key_1'])
d2['bitcoin_key_2'] = bfh(d['bitcoin_key_2'])
ci = ChannelInfo(d2)
ci.capacity_sat = d['capacity_sat']
ci.policy_node1 = ChannelInfoDirectedPolicy.from_json(d['policy_node1'])
ci.policy_node2 = ChannelInfoDirectedPolicy.from_json(d['policy_node2'])
return ci
@property
def msg_payload(self):
return bytes.fromhex(self.msg_payload_hex)
def set_capacity(self, capacity):
self.capacity_sat = capacity
def on_channel_update(self, msg_payload, trusted=False):
assert self.channel_id == msg_payload['short_channel_id']
flags = int.from_bytes(msg_payload['channel_flags'], 'big')
direction = flags & ChannelInfoDirectedPolicy.FLAG_DIRECTION
new_policy = ChannelInfoDirectedPolicy(msg_payload)
def on_channel_update(self, msg: dict, trusted=False):
assert self.short_channel_id == msg['short_channel_id'].hex()
flags = int.from_bytes(msg['channel_flags'], 'big')
direction = flags & FLAG_DIRECTION
if direction == 0:
old_policy = self.policy_node1
node_id = self.node_id_1
node_id = self.node1_id
else:
old_policy = self.policy_node2
node_id = self.node_id_2
if old_policy and old_policy.timestamp >= new_policy.timestamp:
node_id = self.node2_id
new_policy = Policy.from_msg(msg, node_id, self.short_channel_id)
old_policy = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node=node_id).one_or_none()
if not old_policy:
DBSession.add(new_policy)
return
if old_policy.timestamp >= new_policy.timestamp:
return # ignore
if not trusted and not verify_sig_for_channel_update(msg_payload, node_id):
if not trusted and not verify_sig_for_channel_update(msg, bytes.fromhex(node_id)):
return # ignore
# save new policy
if direction == 0:
self.policy_node1 = new_policy
else:
self.policy_node2 = new_policy
old_policy.cltv_expiry_delta = new_policy.cltv_expiry_delta
old_policy.htlc_minimum_msat = new_policy.htlc_minimum_msat
old_policy.htlc_maximum_msat = new_policy.htlc_maximum_msat
old_policy.fee_base_msat = new_policy.fee_base_msat
old_policy.fee_proportional_millionths = new_policy.fee_proportional_millionths
old_policy.channel_flags = new_policy.channel_flags
old_policy.timestamp = new_policy.timestamp
def get_policy_for_node(self, node_id: bytes) -> Optional['ChannelInfoDirectedPolicy']:
if node_id == self.node_id_1:
return self.policy_node1
elif node_id == self.node_id_2:
return self.policy_node2
else:
raise Exception('node_id {} not in channel {}'.format(node_id, self.channel_id))
def get_policy_for_node(self, node) -> Optional['Policy']:
"""
raises when initiator/non-initiator both unequal node
"""
if node.hex() not in (self.node1_id, self.node2_id):
raise Exception("the given node is not a party in this channel")
n1 = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node1_id).one_or_none()
if n1:
return n1
n2 = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node2_id).one_or_none()
return n2
class Policy(Base):
__tablename__ = 'policy'
start_node = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
short_channel_id = Column(String(64), ForeignKey('channel_info.short_channel_id'), primary_key=True)
cltv_expiry_delta = Column(Integer, nullable=False)
htlc_minimum_msat = Column(Integer, nullable=False)
htlc_maximum_msat = Column(Integer)
fee_base_msat = Column(Integer, nullable=False)
fee_proportional_millionths = Column(Integer, nullable=False)
channel_flags = Column(Integer, nullable=False)
timestamp = Column(DateTime, nullable=False)
class ChannelInfoDirectedPolicy:
FLAG_DIRECTION = 1 << 0
FLAG_DISABLE = 1 << 1
def __init__(self, channel_update_payload):
@staticmethod
def from_msg(channel_update_payload, start_node, short_channel_id):
cltv_expiry_delta = channel_update_payload['cltv_expiry_delta']
htlc_minimum_msat = channel_update_payload['htlc_minimum_msat']
fee_base_msat = channel_update_payload['fee_base_msat']
@@ -155,61 +166,52 @@ class ChannelInfoDirectedPolicy:
timestamp = channel_update_payload['timestamp']
htlc_maximum_msat = channel_update_payload.get('htlc_maximum_msat') # optional
self.cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big")
self.htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big")
self.htlc_maximum_msat = int.from_bytes(htlc_maximum_msat, "big") if htlc_maximum_msat else None
self.fee_base_msat = int.from_bytes(fee_base_msat, "big")
self.fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big")
self.channel_flags = int.from_bytes(channel_flags, "big")
self.timestamp = int.from_bytes(timestamp, "big")
cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big")
htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big")
htlc_maximum_msat = int.from_bytes(htlc_maximum_msat, "big") if htlc_maximum_msat else None
fee_base_msat = int.from_bytes(fee_base_msat, "big")
fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big")
channel_flags = int.from_bytes(channel_flags, "big")
timestamp = datetime.datetime.fromtimestamp(int.from_bytes(timestamp, "big"))
self.disabled = self.channel_flags & self.FLAG_DISABLE
return Policy(start_node=start_node,
short_channel_id=short_channel_id,
cltv_expiry_delta=cltv_expiry_delta,
htlc_minimum_msat=htlc_minimum_msat,
fee_base_msat=fee_base_msat,
fee_proportional_millionths=fee_proportional_millionths,
channel_flags=channel_flags,
timestamp=timestamp,
htlc_maximum_msat=htlc_maximum_msat)
def to_json(self) -> dict:
d = {}
d['cltv_expiry_delta'] = self.cltv_expiry_delta
d['htlc_minimum_msat'] = self.htlc_minimum_msat
d['fee_base_msat'] = self.fee_base_msat
d['fee_proportional_millionths'] = self.fee_proportional_millionths
d['channel_flags'] = self.channel_flags
d['timestamp'] = self.timestamp
if self.htlc_maximum_msat:
d['htlc_maximum_msat'] = self.htlc_maximum_msat
return d
def is_disabled(self):
return self.channel_flags & FLAG_DISABLE
@classmethod
def from_json(cls, d: dict):
if d is None: return None
d2 = {}
d2['cltv_expiry_delta'] = d['cltv_expiry_delta'].to_bytes(2, "big")
d2['htlc_minimum_msat'] = d['htlc_minimum_msat'].to_bytes(8, "big")
d2['htlc_maximum_msat'] = d['htlc_maximum_msat'].to_bytes(8, "big") if d.get('htlc_maximum_msat') else None
d2['fee_base_msat'] = d['fee_base_msat'].to_bytes(4, "big")
d2['fee_proportional_millionths'] = d['fee_proportional_millionths'].to_bytes(4, "big")
d2['channel_flags'] = d['channel_flags'].to_bytes(1, "big")
d2['timestamp'] = d['timestamp'].to_bytes(4, "big")
return ChannelInfoDirectedPolicy(d2)
class NodeInfoInDB(Base):
__tablename__ = 'node_info'
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
features = Column(Integer, nullable=False)
timestamp = Column(Integer, nullable=False)
alias = Column(String(64), nullable=False)
def get_addresses(self):
return DBSession.query(AddressInDB).join(NodeInfoInDB).filter_by(node_id = self.node_id).all()
class NodeInfo(PrintError):
def __init__(self, node_announcement_payload, addresses_already_parsed=False):
self.pubkey = node_announcement_payload['node_id']
self.features_len = node_announcement_payload['flen']
self.features = node_announcement_payload['features']
enabled_features = list_enabled_bits(int.from_bytes(self.features, "big"))
for fbit in enabled_features:
if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0:
raise UnknownEvenFeatureBits()
@staticmethod
def from_msg(node_announcement_payload, addresses_already_parsed=False):
node_id = node_announcement_payload['node_id'].hex()
features = int.from_bytes(node_announcement_payload['features'], "big")
validate_features(features)
if not addresses_already_parsed:
self.addresses = self.parse_addresses_field(node_announcement_payload['addresses'])
addresses = NodeInfoInDB.parse_addresses_field(node_announcement_payload['addresses'])
else:
self.addresses = node_announcement_payload['addresses']
self.alias = node_announcement_payload['alias'].rstrip(b'\x00')
self.timestamp = int.from_bytes(node_announcement_payload['timestamp'], "big")
addresses = node_announcement_payload['addresses']
alias = node_announcement_payload['alias'].rstrip(b'\x00').hex()
timestamp = datetime.datetime.fromtimestamp(int.from_bytes(node_announcement_payload['timestamp'], "big"))
return NodeInfoInDB(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [AddressInDB(host=host, port=port, node_id=node_id, last_connected_date=datetime.datetime.now()) for host, port in addresses]
@classmethod
def parse_addresses_field(cls, addresses_field):
@staticmethod
def parse_addresses_field(addresses_field):
buf = addresses_field
def read(n):
nonlocal buf
@@ -248,243 +250,233 @@ class NodeInfo(PrintError):
break
return addresses
def to_json(self) -> dict:
d = {}
d['node_id'] = bh2u(self.pubkey)
d['flen'] = bh2u(self.features_len)
d['features'] = bh2u(self.features)
d['addresses'] = self.addresses
d['alias'] = bh2u(self.alias)
d['timestamp'] = self.timestamp
return d
class AddressInDB(Base):
__tablename__ = 'address'
node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
host = Column(String(256), primary_key=True)
port = Column(Integer, primary_key=True)
last_connected_date = Column(DateTime(), nullable=False)
@classmethod
def from_json(cls, d: dict):
if d is None: return None
d2 = {}
d2['node_id'] = bfh(d['node_id'])
d2['flen'] = bfh(d['flen'])
d2['features'] = bfh(d['features'])
d2['addresses'] = d['addresses']
d2['alias'] = bfh(d['alias'])
d2['timestamp'] = d['timestamp'].to_bytes(4, "big")
return NodeInfo(d2, addresses_already_parsed=True)
class ChannelDB(JsonDB):
class ChannelDB:
NUM_MAX_RECENT_PEERS = 20
def __init__(self, network: 'Network'):
global engine
self.network = network
path = os.path.join(get_headers_dir(network.config), 'channel_db')
JsonDB.__init__(self, path)
self.num_nodes = 0
self.num_channels = 0
self.path = os.path.join(get_headers_dir(network.config), 'channel_db.sqlite3')
engine = create_engine('sqlite:///' + self.path)#, echo=True)
DBSession.remove()
DBSession.configure(bind=engine, autoflush=False)
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
self.lock = threading.RLock()
self._id_to_channel_info = {} # type: Dict[bytes, ChannelInfo]
self._channels_for_node = defaultdict(set) # node -> set(short_channel_id)
self.nodes = {} # node_id -> NodeInfo
self._recent_peers = []
self._last_good_address = {} # node_id -> LNPeerAddr
# (intentionally not persisted)
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], ChannelInfoDirectedPolicy]
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
self.ca_verifier = LNChannelVerifier(network, self)
self.load_data()
def update_counts(self):
self.num_channels = DBSession.query(ChannelInfoInDB).count()
self.num_nodes = DBSession.query(NodeInfoInDB).count()
def load_data(self):
if os.path.exists(self.path):
with open(self.path, "r", encoding='utf-8') as f:
raw = f.read()
self.data = json.loads(raw)
# channels
channel_infos = self.get('channel_infos', {})
for short_channel_id, channel_info_d in channel_infos.items():
channel_info = ChannelInfo.from_json(channel_info_d)
short_channel_id = bfh(short_channel_id)
self.add_verified_channel_info(short_channel_id, channel_info)
# nodes
node_infos = self.get('node_infos', {})
for node_id, node_info_d in node_infos.items():
node_info = NodeInfo.from_json(node_info_d)
node_id = bfh(node_id)
self.nodes[node_id] = node_info
# recent peers
recent_peers = self.get('recent_peers', {})
for host, port, pubkey in recent_peers:
peer = LNPeerAddr(str(host), int(port), bfh(pubkey))
self._recent_peers.append(peer)
# last good address
last_good_addr = self.get('last_good_address', {})
for node_id, host_and_port in last_good_addr.items():
host, port = host_and_port
self._last_good_address[bfh(node_id)] = LNPeerAddr(str(host), int(port), bfh(node_id))
def add_recent_peer(self, peer : LNPeerAddr):
addr = DBSession.query(AddressInDB).filter_by(node_id = peer.pubkey.hex()).one_or_none()
if addr is None:
addr = AddressInDB(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now())
else:
addr.last_connected_date = datetime.datetime.now()
DBSession.add(addr)
DBSession.commit()
def save_data(self):
with self.lock:
# channels
channel_infos = {}
for short_channel_id, channel_info in self._id_to_channel_info.items():
channel_infos[bh2u(short_channel_id)] = channel_info
self.put('channel_infos', channel_infos)
# nodes
node_infos = {}
for node_id, node_info in self.nodes.items():
node_infos[bh2u(node_id)] = node_info
self.put('node_infos', node_infos)
# recent peers
recent_peers = []
for peer in self._recent_peers:
recent_peers.append(
[str(peer.host), int(peer.port), bh2u(peer.pubkey)])
self.put('recent_peers', recent_peers)
# last good address
last_good_addr = {}
for node_id, peer in self._last_good_address.items():
last_good_addr[bh2u(node_id)] = [str(peer.host), int(peer.port)]
self.put('last_good_address', last_good_addr)
self.write()
def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes):
unshuffled = DBSession \
.query(NodeInfoInDB) \
.filter(not_(NodeInfoInDB.node_id.in_(x.hex() for x in node_ids_bytes))) \
.limit(200) \
.all()
return random.sample(unshuffled, len(unshuffled))
def __len__(self):
# number of channels
return len(self._id_to_channel_info)
def nodes_get(self, node_id):
return self.network.run_from_another_thread(self._nodes_get(node_id))
def capacity(self):
# capacity of the network
return sum(c.capacity_sat for c in self._id_to_channel_info.values() if c.capacity_sat is not None)
async def _nodes_get(self, node_id):
return DBSession \
.query(NodeInfoInDB) \
.filter_by(node_id = node_id.hex()) \
.one_or_none()
def get_channel_info(self, channel_id: bytes) -> Optional[ChannelInfo]:
return self._id_to_channel_info.get(channel_id, None)
def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
adr_db = DBSession \
.query(AddressInDB) \
.filter_by(node_id = node_id.hex()) \
.order_by(AddressInDB.last_connected_date.desc()) \
.one_or_none()
if not adr_db:
return None
return LNPeerAddr(adr_db.host, adr_db.port, bytes.fromhex(adr_db.node_id))
def get_recent_peers(self):
return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in DBSession \
.query(AddressInDB) \
.select_from(NodeInfoInDB) \
.order_by(AddressInDB.last_connected_date.desc()) \
.limit(self.NUM_MAX_RECENT_PEERS)]
def get_channel_info(self, channel_id: bytes):
return self.chan_query_for_id(channel_id).one_or_none()
def get_channels_for_node(self, node_id):
"""Returns the set of channels that have node_id as one of the endpoints."""
return self._channels_for_node[node_id]
condition = or_(
ChannelInfoInDB.node1_id == node_id.hex(),
ChannelInfoInDB.node2_id == node_id.hex())
rows = DBSession.query(ChannelInfoInDB).filter(condition).all()
return [bytes.fromhex(x.short_channel_id) for x in rows]
def add_verified_channel_info(self, short_channel_id: bytes, channel_info: ChannelInfo):
with self.lock:
self._id_to_channel_info[short_channel_id] = channel_info
self._channels_for_node[channel_info.node_id_1].add(short_channel_id)
self._channels_for_node[channel_info.node_id_2].add(short_channel_id)
def add_verified_channel_info(self, short_id, capacity):
# called from lnchannelverifier
channel_info = self.get_channel_info(short_id)
channel_info.trusted = True
channel_info.capacity = capacity
DBSession.commit()
@profiler
def on_channel_announcement(self, msg_payloads, trusted=False):
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
for msg in msg_payloads:
short_channel_id = msg['short_channel_id']
if DBSession.query(ChannelInfoInDB).filter_by(short_channel_id = bh2u(short_channel_id)).count():
continue
if constants.net.rev_genesis_bytes() != msg['chain_hash']:
#self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash'])))
continue
try:
channel_info = ChannelInfoInDB.from_msg(msg)
except UnknownEvenFeatureBits:
continue
channel_info.trusted = trusted
DBSession.add(channel_info)
if not trusted: self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
DBSession.commit()
self.network.trigger_callback('ln_status')
self.update_counts()
def get_recent_peers(self):
with self.lock:
return list(self._recent_peers)
@profiler
def on_channel_update(self, msg_payloads, trusted=False):
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads]
channel_infos_list = DBSession.query(ChannelInfoInDB).filter(ChannelInfoInDB.short_channel_id.in_(short_channel_ids)).all()
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
for msg_payload in msg_payloads:
short_channel_id = msg_payload['short_channel_id']
if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']:
continue
channel_info = channel_infos.get(short_channel_id)
channel_info.on_channel_update(msg_payload, trusted=trusted)
DBSession.commit()
def add_recent_peer(self, peer: LNPeerAddr):
with self.lock:
# list is ordered
if peer in self._recent_peers:
self._recent_peers.remove(peer)
self._recent_peers.insert(0, peer)
self._recent_peers = self._recent_peers[:self.NUM_MAX_RECENT_PEERS]
self._last_good_address[peer.pubkey] = peer
@profiler
def on_node_announcement(self, msg_payloads):
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
addresses = DBSession.query(AddressInDB).all()
have_addr = {}
for addr in addresses:
have_addr[(addr.node_id, addr.host, addr.port)] = addr
def get_last_good_address(self, node_id: bytes) -> Optional[LNPeerAddr]:
return self._last_good_address.get(node_id, None)
def on_channel_announcement(self, msg_payload, trusted=False):
short_channel_id = msg_payload['short_channel_id']
if short_channel_id in self._id_to_channel_info:
return
if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']:
#self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash'])))
return
try:
channel_info = ChannelInfo(msg_payload)
except UnknownEvenFeatureBits:
return
if trusted:
self.add_verified_channel_info(short_channel_id, channel_info)
else:
self.ca_verifier.add_new_channel_info(channel_info)
def on_channel_update(self, msg_payload, trusted=False):
short_channel_id = msg_payload['short_channel_id']
if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']:
return
# try finding channel in pending db
channel_info = self.ca_verifier.get_pending_channel_info(short_channel_id)
if channel_info is None:
# try finding channel in verified db
channel_info = self._id_to_channel_info.get(short_channel_id, None)
if channel_info is None:
self.print_error("could not find", short_channel_id)
raise NotFoundChanAnnouncementForUpdate()
channel_info.on_channel_update(msg_payload, trusted=trusted)
def on_node_announcement(self, msg_payload):
pubkey = msg_payload['node_id']
signature = msg_payload['signature']
h = sha256d(msg_payload['raw'][66:])
if not ecc.verify_signature(pubkey, signature, h):
return
old_node_info = self.nodes.get(pubkey, None)
try:
new_node_info = NodeInfo(msg_payload)
except UnknownEvenFeatureBits:
return
# TODO if this message is for a new node, and if we have no associated
# channels for this node, we should ignore the message and return here,
# to mitigate DOS. but race condition: the channels we have for this
# node, might be under verification in self.ca_verifier, what then?
if old_node_info and old_node_info.timestamp >= new_node_info.timestamp:
return # ignore
self.nodes[pubkey] = new_node_info
nodes = DBSession.query(NodeInfoInDB).all()
timestamps = {}
for node in nodes:
no_millisecs = node.timestamp[:len("0000-00-00 00:00:00")]
timestamps[bfh(node.node_id)] = datetime.datetime.strptime(no_millisecs, "%Y-%m-%d %H:%M:%S")
old_addr = None
for msg_payload in msg_payloads:
pubkey = msg_payload['node_id']
signature = msg_payload['signature']
h = sha256d(msg_payload['raw'][66:])
if not ecc.verify_signature(pubkey, signature, h):
continue
try:
new_node_info, addresses = NodeInfoInDB.from_msg(msg_payload)
except UnknownEvenFeatureBits:
continue
if timestamps.get(pubkey) and timestamps[pubkey] >= new_node_info.timestamp:
continue # ignore
DBSession.add(new_node_info)
for new_addr in addresses:
key = (new_addr.node_id, new_addr.host, new_addr.port)
old_addr = have_addr.get(key)
if old_addr:
# since old_addr is embedded in have_addr,
# it will still live when commmit is called
old_addr.last_connected_date = new_addr.last_connected_date
del new_addr
else:
DBSession.add(new_addr)
have_addr[key] = new_addr
# TODO if this message is for a new node, and if we have no associated
# channels for this node, we should ignore the message and return here,
# to mitigate DOS. but race condition: the channels we have for this
# node, might be under verification in self.ca_verifier, what then?
del nodes, addresses
if old_addr:
del old_addr
DBSession.commit()
self.network.trigger_callback('ln_status')
self.update_counts()
def get_routing_policy_for_channel(self, start_node_id: bytes,
short_channel_id: bytes) -> Optional[ChannelInfoDirectedPolicy]:
short_channel_id: bytes) -> Optional[bytes]:
if not start_node_id or not short_channel_id: return None
channel_info = self.get_channel_info(short_channel_id)
if channel_info is not None:
return channel_info.get_policy_for_node(start_node_id)
return self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
if not msg: return None
return Policy.from_msg(msg, None, short_channel_id) # won't actually be written to DB
def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
if not verify_sig_for_channel_update(msg_payload, start_node_id):
return # ignore
short_channel_id = msg_payload['short_channel_id']
policy = ChannelInfoDirectedPolicy(msg_payload)
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = policy
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
def remove_channel(self, short_channel_id):
try:
channel_info = self._id_to_channel_info[short_channel_id]
except KeyError:
self.print_error(f'remove_channel: cannot find channel {bh2u(short_channel_id)}')
return
self._id_to_channel_info.pop(short_channel_id, None)
for node in (channel_info.node_id_1, channel_info.node_id_2):
try:
self._channels_for_node[node].remove(short_channel_id)
except KeyError:
pass
self.chan_query_for_id(short_channel_id).delete('evaluate')
DBSession.commit()
def chan_query_for_id(self, short_channel_id) -> Query:
return DBSession.query(ChannelInfoInDB).filter_by(short_channel_id = short_channel_id.hex())
def print_graph(self, full_ids=False):
# used for debugging.
# FIXME there is a race here - iterables could change size from another thread
def other_node_id(node_id, channel_id):
channel_info = self._id_to_channel_info[channel_id]
if node_id == channel_info.node_id_1:
other = channel_info.node_id_2
channel_info = self.get_channel_info(channel_id)
if node_id == channel_info.node1_id:
other = channel_info.node2_id
else:
other = channel_info.node_id_1
other = channel_info.node1_id
return other if full_ids else other[-4:]
self.print_msg('node: {(channel, other_node), ...}')
for node_id, short_channel_ids in list(self._channels_for_node.items()):
short_channel_ids = {(bh2u(cid), bh2u(other_node_id(node_id, cid)))
for cid in short_channel_ids}
node_id = bh2u(node_id) if full_ids else bh2u(node_id[-4:])
self.print_msg('{}: {}'.format(node_id, short_channel_ids))
self.print_msg('nodes')
for node in DBSession.query(NodeInfoInDB).all():
self.print_msg(node)
self.print_msg('channel: node1, node2, direction')
for short_channel_id, channel_info in list(self._id_to_channel_info.items()):
node1 = channel_info.node_id_1
node2 = channel_info.node_id_2
self.print_msg('channels')
for channel_info in DBSession.query(ChannelInfoInDB).all():
node1 = channel_info.node1_id
node2 = channel_info.node2_id
direction1 = channel_info.get_policy_for_node(node1) is not None
direction2 = channel_info.get_policy_for_node(node2) is not None
if direction1 and direction2:
@@ -514,8 +506,10 @@ class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes),
+ (amount_msat * self.fee_proportional_millionths // 1_000_000)
@classmethod
def from_channel_policy(cls, channel_policy: ChannelInfoDirectedPolicy,
def from_channel_policy(cls, channel_policy: 'Policy',
short_channel_id: bytes, end_node: bytes) -> 'RouteEdge':
assert type(short_channel_id) is bytes
assert type(end_node) is bytes
return RouteEdge(end_node,
short_channel_id,
channel_policy.fee_base_msat,
@@ -582,7 +576,7 @@ class LNPathFinder(PrintError):
channel_policy = channel_info.get_policy_for_node(start_node)
if channel_policy is None: return float('inf'), 0
if channel_policy.disabled: return float('inf'), 0
if channel_policy.is_disabled(): return float('inf'), 0
route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node)
if payment_amt_msat < channel_policy.htlc_minimum_msat:
return float('inf'), 0 # payment amount too little
@@ -611,6 +605,8 @@ class LNPathFinder(PrintError):
To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1];
i.e. an element reads as, "to get to node_id, travel through short_channel_id"
"""
assert type(nodeA) is bytes
assert type(nodeB) is bytes
assert type(invoice_amount_msat) is int
if my_channels is None: my_channels = []
my_channels = {chan.short_channel_id: chan for chan in my_channels}
@@ -657,9 +653,10 @@ class LNPathFinder(PrintError):
# so there are duplicates in the queue, that we discard now:
continue
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode):
assert type(edge_channel_id) is bytes
if edge_channel_id in self.blacklist: continue
channel_info = self.channel_db.get_channel_info(edge_channel_id)
edge_startnode = channel_info.node_id_2 if channel_info.node_id_1 == edge_endnode else channel_info.node_id_1
edge_startnode = bfh(channel_info.node2_id) if bfh(channel_info.node1_id) == edge_endnode else bfh(channel_info.node1_id)
inspect_edge()
else:
return None # no path found
@@ -682,7 +679,7 @@ class LNPathFinder(PrintError):
for node_id, short_channel_id in path:
channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id)
if channel_policy is None:
raise Exception(f'cannot find channel policy for short_channel_id: {bh2u(short_channel_id)}')
raise NoChannelPolicy(short_channel_id)
route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id))
prev_node_id = node_id
return route