sqlite in lnrouter
This commit is contained in:
@@ -105,25 +105,21 @@ class ChannelsList(MyTreeView):
|
|||||||
|
|
||||||
def update_status(self):
|
def update_status(self):
|
||||||
channel_db = self.parent.network.channel_db
|
channel_db = self.parent.network.channel_db
|
||||||
num_nodes = len(channel_db.nodes)
|
|
||||||
num_channels = len(channel_db)
|
|
||||||
num_peers = len(self.parent.wallet.lnworker.peers)
|
num_peers = len(self.parent.wallet.lnworker.peers)
|
||||||
msg = _('{} peers, {} nodes, {} channels.').format(num_peers, num_nodes, num_channels)
|
msg = _('{} peers, {} nodes, {} channels.').format(num_peers, channel_db.num_nodes, channel_db.num_channels)
|
||||||
self.status.setText(msg)
|
self.status.setText(msg)
|
||||||
|
|
||||||
def statistics_dialog(self):
|
def statistics_dialog(self):
|
||||||
channel_db = self.parent.network.channel_db
|
channel_db = self.parent.network.channel_db
|
||||||
num_nodes = len(channel_db.nodes)
|
|
||||||
num_channels = len(channel_db)
|
|
||||||
capacity = self.parent.format_amount(channel_db.capacity()) + ' '+ self.parent.base_unit()
|
capacity = self.parent.format_amount(channel_db.capacity()) + ' '+ self.parent.base_unit()
|
||||||
d = WindowModalDialog(self.parent, _('Lightning Network Statistics'))
|
d = WindowModalDialog(self.parent, _('Lightning Network Statistics'))
|
||||||
d.setMinimumWidth(400)
|
d.setMinimumWidth(400)
|
||||||
vbox = QVBoxLayout(d)
|
vbox = QVBoxLayout(d)
|
||||||
h = QGridLayout()
|
h = QGridLayout()
|
||||||
h.addWidget(QLabel(_('Nodes') + ':'), 0, 0)
|
h.addWidget(QLabel(_('Nodes') + ':'), 0, 0)
|
||||||
h.addWidget(QLabel('{}'.format(num_nodes)), 0, 1)
|
h.addWidget(QLabel('{}'.format(channel_db.num_nodes)), 0, 1)
|
||||||
h.addWidget(QLabel(_('Channels') + ':'), 1, 0)
|
h.addWidget(QLabel(_('Channels') + ':'), 1, 0)
|
||||||
h.addWidget(QLabel('{}'.format(num_channels)), 1, 1)
|
h.addWidget(QLabel('{}'.format(channel_db.num_channels)), 1, 1)
|
||||||
h.addWidget(QLabel(_('Capacity') + ':'), 2, 0)
|
h.addWidget(QLabel(_('Capacity') + ':'), 2, 0)
|
||||||
h.addWidget(QLabel(capacity), 2, 1)
|
h.addWidget(QLabel(capacity), 2, 1)
|
||||||
vbox.addLayout(h)
|
vbox.addLayout(h)
|
||||||
|
|||||||
@@ -55,6 +55,10 @@ class Peer(PrintError):
|
|||||||
def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase,
|
def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase,
|
||||||
request_initial_sync=False):
|
request_initial_sync=False):
|
||||||
self.initialized = asyncio.Event()
|
self.initialized = asyncio.Event()
|
||||||
|
self.node_anns = []
|
||||||
|
self.chan_anns = []
|
||||||
|
self.chan_upds = []
|
||||||
|
self.last_chan_db_upd = time.time()
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
self.pubkey = pubkey
|
self.pubkey = pubkey
|
||||||
self.lnworker = lnworker
|
self.lnworker = lnworker
|
||||||
@@ -152,10 +156,6 @@ class Peer(PrintError):
|
|||||||
if channel_id not in self.funding_created: raise Exception("Got unknown funding_created")
|
if channel_id not in self.funding_created: raise Exception("Got unknown funding_created")
|
||||||
self.funding_created[channel_id].put_nowait(payload)
|
self.funding_created[channel_id].put_nowait(payload)
|
||||||
|
|
||||||
def on_node_announcement(self, payload):
|
|
||||||
self.channel_db.on_node_announcement(payload)
|
|
||||||
self.network.trigger_callback('ln_status')
|
|
||||||
|
|
||||||
def on_init(self, payload):
|
def on_init(self, payload):
|
||||||
if self.initialized.is_set():
|
if self.initialized.is_set():
|
||||||
self.print_error("ALREADY INITIALIZED BUT RECEIVED INIT")
|
self.print_error("ALREADY INITIALIZED BUT RECEIVED INIT")
|
||||||
@@ -175,20 +175,14 @@ class Peer(PrintError):
|
|||||||
self.send_message('gossip_timestamp_filter', chain_hash=constants.net.rev_genesis_bytes(), first_timestamp=first_timestamp, timestamp_range=b"\xff"*4)
|
self.send_message('gossip_timestamp_filter', chain_hash=constants.net.rev_genesis_bytes(), first_timestamp=first_timestamp, timestamp_range=b"\xff"*4)
|
||||||
self.initialized.set()
|
self.initialized.set()
|
||||||
|
|
||||||
|
def on_node_announcement(self, payload):
|
||||||
|
self.node_anns.append(payload)
|
||||||
|
|
||||||
def on_channel_update(self, payload):
|
def on_channel_update(self, payload):
|
||||||
try:
|
self.chan_upds.append(payload)
|
||||||
self.channel_db.on_channel_update(payload)
|
|
||||||
except NotFoundChanAnnouncementForUpdate:
|
|
||||||
# If it's for a direct channel with this peer, save it for later, as it might be
|
|
||||||
# for our own channel (and we might not yet know the short channel id for that)
|
|
||||||
short_channel_id = payload['short_channel_id']
|
|
||||||
self.print_error("not found channel announce for channel update in db", bh2u(short_channel_id))
|
|
||||||
self.orphan_channel_updates[short_channel_id] = payload
|
|
||||||
while len(self.orphan_channel_updates) > 10:
|
|
||||||
self.orphan_channel_updates.popitem(last=False)
|
|
||||||
|
|
||||||
def on_channel_announcement(self, payload):
|
def on_channel_announcement(self, payload):
|
||||||
self.channel_db.on_channel_announcement(payload)
|
self.chan_anns.append(payload)
|
||||||
|
|
||||||
def on_announcement_signatures(self, payload):
|
def on_announcement_signatures(self, payload):
|
||||||
channel_id = payload['channel_id']
|
channel_id = payload['channel_id']
|
||||||
@@ -230,6 +224,15 @@ class Peer(PrintError):
|
|||||||
# loop
|
# loop
|
||||||
async for msg in self.transport.read_messages():
|
async for msg in self.transport.read_messages():
|
||||||
self.process_message(msg)
|
self.process_message(msg)
|
||||||
|
await asyncio.sleep(.01)
|
||||||
|
if time.time() - self.last_chan_db_upd > 5:
|
||||||
|
self.last_chan_db_upd = time.time()
|
||||||
|
self.channel_db.on_node_announcement(self.node_anns)
|
||||||
|
self.node_anns = []
|
||||||
|
self.channel_db.on_channel_announcement(self.chan_anns)
|
||||||
|
self.chan_anns = []
|
||||||
|
self.channel_db.on_channel_update(self.chan_upds)
|
||||||
|
self.chan_upds = []
|
||||||
self.ping_if_required()
|
self.ping_if_required()
|
||||||
|
|
||||||
def close_and_cleanup(self):
|
def close_and_cleanup(self):
|
||||||
|
|||||||
@@ -23,6 +23,8 @@
|
|||||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import random
|
||||||
import queue
|
import queue
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
@@ -33,6 +35,14 @@ import binascii
|
|||||||
import base64
|
import base64
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlalchemy.orm.query import Query
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.sql import not_, or_
|
||||||
|
from sqlalchemy.orm import scoped_session
|
||||||
|
|
||||||
from . import constants
|
from . import constants
|
||||||
from .util import PrintError, bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
|
from .util import PrintError, bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
|
||||||
from .storage import JsonDB
|
from .storage import JsonDB
|
||||||
@@ -41,112 +51,113 @@ from .crypto import sha256d
|
|||||||
from . import ecc
|
from . import ecc
|
||||||
from .lnutil import (LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH,
|
from .lnutil import (LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH,
|
||||||
NotFoundChanAnnouncementForUpdate)
|
NotFoundChanAnnouncementForUpdate)
|
||||||
|
from .lnmsg import encode_msg
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .lnchannel import Channel
|
from .lnchannel import Channel
|
||||||
from .network import Network
|
from .network import Network
|
||||||
|
|
||||||
|
|
||||||
class UnknownEvenFeatureBits(Exception): pass
|
class UnknownEvenFeatureBits(Exception): pass
|
||||||
|
class NoChannelPolicy(Exception):
|
||||||
|
def __init__(self, short_channel_id: bytes):
|
||||||
|
super().__init__(f'cannot find channel policy for short_channel_id: {bh2u(short_channel_id)}')
|
||||||
|
|
||||||
|
def validate_features(features : int):
|
||||||
|
enabled_features = list_enabled_bits(features)
|
||||||
|
for fbit in enabled_features:
|
||||||
|
if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0:
|
||||||
|
raise UnknownEvenFeatureBits()
|
||||||
|
|
||||||
|
Base = declarative_base()
|
||||||
|
session_factory = sessionmaker()
|
||||||
|
DBSession = scoped_session(session_factory)
|
||||||
|
engine = None
|
||||||
|
|
||||||
class ChannelInfo(PrintError):
|
FLAG_DISABLE = 1 << 1
|
||||||
|
FLAG_DIRECTION = 1 << 0
|
||||||
|
|
||||||
def __init__(self, channel_announcement_payload):
|
class ChannelInfoInDB(Base):
|
||||||
self.features_len = channel_announcement_payload['len']
|
__tablename__ = 'channel_info'
|
||||||
self.features = channel_announcement_payload['features']
|
short_channel_id = Column(String(64), primary_key=True)
|
||||||
enabled_features = list_enabled_bits(int.from_bytes(self.features, "big"))
|
node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
|
||||||
for fbit in enabled_features:
|
node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
|
||||||
if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0:
|
capacity_sat = Column(Integer)
|
||||||
raise UnknownEvenFeatureBits()
|
msg_payload_hex = Column(String(1024), nullable=False)
|
||||||
|
trusted = Column(Boolean, nullable=False)
|
||||||
|
|
||||||
self.channel_id = channel_announcement_payload['short_channel_id']
|
@staticmethod
|
||||||
self.node_id_1 = channel_announcement_payload['node_id_1']
|
def from_msg(channel_announcement_payload):
|
||||||
self.node_id_2 = channel_announcement_payload['node_id_2']
|
features = int.from_bytes(channel_announcement_payload['features'], 'big')
|
||||||
assert type(self.node_id_1) is bytes
|
validate_features(features)
|
||||||
assert type(self.node_id_2) is bytes
|
|
||||||
assert list(sorted([self.node_id_1, self.node_id_2])) == [self.node_id_1, self.node_id_2]
|
|
||||||
|
|
||||||
self.bitcoin_key_1 = channel_announcement_payload['bitcoin_key_1']
|
channel_id = channel_announcement_payload['short_channel_id'].hex()
|
||||||
self.bitcoin_key_2 = channel_announcement_payload['bitcoin_key_2']
|
node_id_1 = channel_announcement_payload['node_id_1'].hex()
|
||||||
|
node_id_2 = channel_announcement_payload['node_id_2'].hex()
|
||||||
|
assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
|
||||||
|
|
||||||
# this field does not get persisted
|
msg_payload_hex = encode_msg('channel_announcement', **channel_announcement_payload).hex()
|
||||||
self.msg_payload = channel_announcement_payload
|
|
||||||
|
|
||||||
self.capacity_sat = None
|
capacity_sat = None
|
||||||
self.policy_node1 = None
|
|
||||||
self.policy_node2 = None
|
|
||||||
|
|
||||||
def to_json(self) -> dict:
|
return ChannelInfoInDB(short_channel_id = channel_id, node1_id = node_id_1,
|
||||||
d = {}
|
node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex,
|
||||||
d['short_channel_id'] = bh2u(self.channel_id)
|
trusted = False)
|
||||||
d['node_id_1'] = bh2u(self.node_id_1)
|
|
||||||
d['node_id_2'] = bh2u(self.node_id_2)
|
|
||||||
d['len'] = bh2u(self.features_len)
|
|
||||||
d['features'] = bh2u(self.features)
|
|
||||||
d['bitcoin_key_1'] = bh2u(self.bitcoin_key_1)
|
|
||||||
d['bitcoin_key_2'] = bh2u(self.bitcoin_key_2)
|
|
||||||
d['policy_node1'] = self.policy_node1
|
|
||||||
d['policy_node2'] = self.policy_node2
|
|
||||||
d['capacity_sat'] = self.capacity_sat
|
|
||||||
return d
|
|
||||||
|
|
||||||
@classmethod
|
@property
|
||||||
def from_json(cls, d: dict):
|
def msg_payload(self):
|
||||||
d2 = {}
|
return bytes.fromhex(self.msg_payload_hex)
|
||||||
d2['short_channel_id'] = bfh(d['short_channel_id'])
|
|
||||||
d2['node_id_1'] = bfh(d['node_id_1'])
|
|
||||||
d2['node_id_2'] = bfh(d['node_id_2'])
|
|
||||||
d2['len'] = bfh(d['len'])
|
|
||||||
d2['features'] = bfh(d['features'])
|
|
||||||
d2['bitcoin_key_1'] = bfh(d['bitcoin_key_1'])
|
|
||||||
d2['bitcoin_key_2'] = bfh(d['bitcoin_key_2'])
|
|
||||||
ci = ChannelInfo(d2)
|
|
||||||
ci.capacity_sat = d['capacity_sat']
|
|
||||||
ci.policy_node1 = ChannelInfoDirectedPolicy.from_json(d['policy_node1'])
|
|
||||||
ci.policy_node2 = ChannelInfoDirectedPolicy.from_json(d['policy_node2'])
|
|
||||||
return ci
|
|
||||||
|
|
||||||
def set_capacity(self, capacity):
|
def on_channel_update(self, msg: dict, trusted=False):
|
||||||
self.capacity_sat = capacity
|
assert self.short_channel_id == msg['short_channel_id'].hex()
|
||||||
|
flags = int.from_bytes(msg['channel_flags'], 'big')
|
||||||
def on_channel_update(self, msg_payload, trusted=False):
|
direction = flags & FLAG_DIRECTION
|
||||||
assert self.channel_id == msg_payload['short_channel_id']
|
|
||||||
flags = int.from_bytes(msg_payload['channel_flags'], 'big')
|
|
||||||
direction = flags & ChannelInfoDirectedPolicy.FLAG_DIRECTION
|
|
||||||
new_policy = ChannelInfoDirectedPolicy(msg_payload)
|
|
||||||
if direction == 0:
|
if direction == 0:
|
||||||
old_policy = self.policy_node1
|
node_id = self.node1_id
|
||||||
node_id = self.node_id_1
|
|
||||||
else:
|
else:
|
||||||
old_policy = self.policy_node2
|
node_id = self.node2_id
|
||||||
node_id = self.node_id_2
|
new_policy = Policy.from_msg(msg, node_id, self.short_channel_id)
|
||||||
if old_policy and old_policy.timestamp >= new_policy.timestamp:
|
old_policy = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node=node_id).one_or_none()
|
||||||
|
if not old_policy:
|
||||||
|
DBSession.add(new_policy)
|
||||||
|
return
|
||||||
|
if old_policy.timestamp >= new_policy.timestamp:
|
||||||
return # ignore
|
return # ignore
|
||||||
if not trusted and not verify_sig_for_channel_update(msg_payload, node_id):
|
if not trusted and not verify_sig_for_channel_update(msg, bytes.fromhex(node_id)):
|
||||||
return # ignore
|
return # ignore
|
||||||
# save new policy
|
old_policy.cltv_expiry_delta = new_policy.cltv_expiry_delta
|
||||||
if direction == 0:
|
old_policy.htlc_minimum_msat = new_policy.htlc_minimum_msat
|
||||||
self.policy_node1 = new_policy
|
old_policy.htlc_maximum_msat = new_policy.htlc_maximum_msat
|
||||||
else:
|
old_policy.fee_base_msat = new_policy.fee_base_msat
|
||||||
self.policy_node2 = new_policy
|
old_policy.fee_proportional_millionths = new_policy.fee_proportional_millionths
|
||||||
|
old_policy.channel_flags = new_policy.channel_flags
|
||||||
|
old_policy.timestamp = new_policy.timestamp
|
||||||
|
|
||||||
def get_policy_for_node(self, node_id: bytes) -> Optional['ChannelInfoDirectedPolicy']:
|
def get_policy_for_node(self, node) -> Optional['Policy']:
|
||||||
if node_id == self.node_id_1:
|
"""
|
||||||
return self.policy_node1
|
raises when initiator/non-initiator both unequal node
|
||||||
elif node_id == self.node_id_2:
|
"""
|
||||||
return self.policy_node2
|
if node.hex() not in (self.node1_id, self.node2_id):
|
||||||
else:
|
raise Exception("the given node is not a party in this channel")
|
||||||
raise Exception('node_id {} not in channel {}'.format(node_id, self.channel_id))
|
n1 = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node1_id).one_or_none()
|
||||||
|
if n1:
|
||||||
|
return n1
|
||||||
|
n2 = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node2_id).one_or_none()
|
||||||
|
return n2
|
||||||
|
|
||||||
|
class Policy(Base):
|
||||||
|
__tablename__ = 'policy'
|
||||||
|
start_node = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
|
||||||
|
short_channel_id = Column(String(64), ForeignKey('channel_info.short_channel_id'), primary_key=True)
|
||||||
|
cltv_expiry_delta = Column(Integer, nullable=False)
|
||||||
|
htlc_minimum_msat = Column(Integer, nullable=False)
|
||||||
|
htlc_maximum_msat = Column(Integer)
|
||||||
|
fee_base_msat = Column(Integer, nullable=False)
|
||||||
|
fee_proportional_millionths = Column(Integer, nullable=False)
|
||||||
|
channel_flags = Column(Integer, nullable=False)
|
||||||
|
timestamp = Column(DateTime, nullable=False)
|
||||||
|
|
||||||
class ChannelInfoDirectedPolicy:
|
@staticmethod
|
||||||
|
def from_msg(channel_update_payload, start_node, short_channel_id):
|
||||||
FLAG_DIRECTION = 1 << 0
|
|
||||||
FLAG_DISABLE = 1 << 1
|
|
||||||
|
|
||||||
def __init__(self, channel_update_payload):
|
|
||||||
cltv_expiry_delta = channel_update_payload['cltv_expiry_delta']
|
cltv_expiry_delta = channel_update_payload['cltv_expiry_delta']
|
||||||
htlc_minimum_msat = channel_update_payload['htlc_minimum_msat']
|
htlc_minimum_msat = channel_update_payload['htlc_minimum_msat']
|
||||||
fee_base_msat = channel_update_payload['fee_base_msat']
|
fee_base_msat = channel_update_payload['fee_base_msat']
|
||||||
@@ -155,61 +166,52 @@ class ChannelInfoDirectedPolicy:
|
|||||||
timestamp = channel_update_payload['timestamp']
|
timestamp = channel_update_payload['timestamp']
|
||||||
htlc_maximum_msat = channel_update_payload.get('htlc_maximum_msat') # optional
|
htlc_maximum_msat = channel_update_payload.get('htlc_maximum_msat') # optional
|
||||||
|
|
||||||
self.cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big")
|
cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big")
|
||||||
self.htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big")
|
htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big")
|
||||||
self.htlc_maximum_msat = int.from_bytes(htlc_maximum_msat, "big") if htlc_maximum_msat else None
|
htlc_maximum_msat = int.from_bytes(htlc_maximum_msat, "big") if htlc_maximum_msat else None
|
||||||
self.fee_base_msat = int.from_bytes(fee_base_msat, "big")
|
fee_base_msat = int.from_bytes(fee_base_msat, "big")
|
||||||
self.fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big")
|
fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big")
|
||||||
self.channel_flags = int.from_bytes(channel_flags, "big")
|
channel_flags = int.from_bytes(channel_flags, "big")
|
||||||
self.timestamp = int.from_bytes(timestamp, "big")
|
timestamp = datetime.datetime.fromtimestamp(int.from_bytes(timestamp, "big"))
|
||||||
|
|
||||||
self.disabled = self.channel_flags & self.FLAG_DISABLE
|
return Policy(start_node=start_node,
|
||||||
|
short_channel_id=short_channel_id,
|
||||||
|
cltv_expiry_delta=cltv_expiry_delta,
|
||||||
|
htlc_minimum_msat=htlc_minimum_msat,
|
||||||
|
fee_base_msat=fee_base_msat,
|
||||||
|
fee_proportional_millionths=fee_proportional_millionths,
|
||||||
|
channel_flags=channel_flags,
|
||||||
|
timestamp=timestamp,
|
||||||
|
htlc_maximum_msat=htlc_maximum_msat)
|
||||||
|
|
||||||
def to_json(self) -> dict:
|
def is_disabled(self):
|
||||||
d = {}
|
return self.channel_flags & FLAG_DISABLE
|
||||||
d['cltv_expiry_delta'] = self.cltv_expiry_delta
|
|
||||||
d['htlc_minimum_msat'] = self.htlc_minimum_msat
|
|
||||||
d['fee_base_msat'] = self.fee_base_msat
|
|
||||||
d['fee_proportional_millionths'] = self.fee_proportional_millionths
|
|
||||||
d['channel_flags'] = self.channel_flags
|
|
||||||
d['timestamp'] = self.timestamp
|
|
||||||
if self.htlc_maximum_msat:
|
|
||||||
d['htlc_maximum_msat'] = self.htlc_maximum_msat
|
|
||||||
return d
|
|
||||||
|
|
||||||
@classmethod
|
class NodeInfoInDB(Base):
|
||||||
def from_json(cls, d: dict):
|
__tablename__ = 'node_info'
|
||||||
if d is None: return None
|
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
|
||||||
d2 = {}
|
features = Column(Integer, nullable=False)
|
||||||
d2['cltv_expiry_delta'] = d['cltv_expiry_delta'].to_bytes(2, "big")
|
timestamp = Column(Integer, nullable=False)
|
||||||
d2['htlc_minimum_msat'] = d['htlc_minimum_msat'].to_bytes(8, "big")
|
alias = Column(String(64), nullable=False)
|
||||||
d2['htlc_maximum_msat'] = d['htlc_maximum_msat'].to_bytes(8, "big") if d.get('htlc_maximum_msat') else None
|
|
||||||
d2['fee_base_msat'] = d['fee_base_msat'].to_bytes(4, "big")
|
|
||||||
d2['fee_proportional_millionths'] = d['fee_proportional_millionths'].to_bytes(4, "big")
|
|
||||||
d2['channel_flags'] = d['channel_flags'].to_bytes(1, "big")
|
|
||||||
d2['timestamp'] = d['timestamp'].to_bytes(4, "big")
|
|
||||||
return ChannelInfoDirectedPolicy(d2)
|
|
||||||
|
|
||||||
|
def get_addresses(self):
|
||||||
|
return DBSession.query(AddressInDB).join(NodeInfoInDB).filter_by(node_id = self.node_id).all()
|
||||||
|
|
||||||
class NodeInfo(PrintError):
|
@staticmethod
|
||||||
|
def from_msg(node_announcement_payload, addresses_already_parsed=False):
|
||||||
def __init__(self, node_announcement_payload, addresses_already_parsed=False):
|
node_id = node_announcement_payload['node_id'].hex()
|
||||||
self.pubkey = node_announcement_payload['node_id']
|
features = int.from_bytes(node_announcement_payload['features'], "big")
|
||||||
self.features_len = node_announcement_payload['flen']
|
validate_features(features)
|
||||||
self.features = node_announcement_payload['features']
|
|
||||||
enabled_features = list_enabled_bits(int.from_bytes(self.features, "big"))
|
|
||||||
for fbit in enabled_features:
|
|
||||||
if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0:
|
|
||||||
raise UnknownEvenFeatureBits()
|
|
||||||
if not addresses_already_parsed:
|
if not addresses_already_parsed:
|
||||||
self.addresses = self.parse_addresses_field(node_announcement_payload['addresses'])
|
addresses = NodeInfoInDB.parse_addresses_field(node_announcement_payload['addresses'])
|
||||||
else:
|
else:
|
||||||
self.addresses = node_announcement_payload['addresses']
|
addresses = node_announcement_payload['addresses']
|
||||||
self.alias = node_announcement_payload['alias'].rstrip(b'\x00')
|
alias = node_announcement_payload['alias'].rstrip(b'\x00').hex()
|
||||||
self.timestamp = int.from_bytes(node_announcement_payload['timestamp'], "big")
|
timestamp = datetime.datetime.fromtimestamp(int.from_bytes(node_announcement_payload['timestamp'], "big"))
|
||||||
|
return NodeInfoInDB(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [AddressInDB(host=host, port=port, node_id=node_id, last_connected_date=datetime.datetime.now()) for host, port in addresses]
|
||||||
|
|
||||||
@classmethod
|
@staticmethod
|
||||||
def parse_addresses_field(cls, addresses_field):
|
def parse_addresses_field(addresses_field):
|
||||||
buf = addresses_field
|
buf = addresses_field
|
||||||
def read(n):
|
def read(n):
|
||||||
nonlocal buf
|
nonlocal buf
|
||||||
@@ -248,243 +250,233 @@ class NodeInfo(PrintError):
|
|||||||
break
|
break
|
||||||
return addresses
|
return addresses
|
||||||
|
|
||||||
def to_json(self) -> dict:
|
class AddressInDB(Base):
|
||||||
d = {}
|
__tablename__ = 'address'
|
||||||
d['node_id'] = bh2u(self.pubkey)
|
node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
|
||||||
d['flen'] = bh2u(self.features_len)
|
host = Column(String(256), primary_key=True)
|
||||||
d['features'] = bh2u(self.features)
|
port = Column(Integer, primary_key=True)
|
||||||
d['addresses'] = self.addresses
|
last_connected_date = Column(DateTime(), nullable=False)
|
||||||
d['alias'] = bh2u(self.alias)
|
|
||||||
d['timestamp'] = self.timestamp
|
|
||||||
return d
|
|
||||||
|
|
||||||
@classmethod
|
class ChannelDB:
|
||||||
def from_json(cls, d: dict):
|
|
||||||
if d is None: return None
|
|
||||||
d2 = {}
|
|
||||||
d2['node_id'] = bfh(d['node_id'])
|
|
||||||
d2['flen'] = bfh(d['flen'])
|
|
||||||
d2['features'] = bfh(d['features'])
|
|
||||||
d2['addresses'] = d['addresses']
|
|
||||||
d2['alias'] = bfh(d['alias'])
|
|
||||||
d2['timestamp'] = d['timestamp'].to_bytes(4, "big")
|
|
||||||
return NodeInfo(d2, addresses_already_parsed=True)
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelDB(JsonDB):
|
|
||||||
|
|
||||||
NUM_MAX_RECENT_PEERS = 20
|
NUM_MAX_RECENT_PEERS = 20
|
||||||
|
|
||||||
def __init__(self, network: 'Network'):
|
def __init__(self, network: 'Network'):
|
||||||
|
global engine
|
||||||
self.network = network
|
self.network = network
|
||||||
|
|
||||||
path = os.path.join(get_headers_dir(network.config), 'channel_db')
|
self.num_nodes = 0
|
||||||
JsonDB.__init__(self, path)
|
self.num_channels = 0
|
||||||
|
|
||||||
|
self.path = os.path.join(get_headers_dir(network.config), 'channel_db.sqlite3')
|
||||||
|
engine = create_engine('sqlite:///' + self.path)#, echo=True)
|
||||||
|
DBSession.remove()
|
||||||
|
DBSession.configure(bind=engine, autoflush=False)
|
||||||
|
|
||||||
|
Base.metadata.drop_all(engine)
|
||||||
|
Base.metadata.create_all(engine)
|
||||||
|
|
||||||
self.lock = threading.RLock()
|
self.lock = threading.RLock()
|
||||||
self._id_to_channel_info = {} # type: Dict[bytes, ChannelInfo]
|
|
||||||
self._channels_for_node = defaultdict(set) # node -> set(short_channel_id)
|
|
||||||
self.nodes = {} # node_id -> NodeInfo
|
|
||||||
self._recent_peers = []
|
|
||||||
self._last_good_address = {} # node_id -> LNPeerAddr
|
|
||||||
|
|
||||||
# (intentionally not persisted)
|
# (intentionally not persisted)
|
||||||
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], ChannelInfoDirectedPolicy]
|
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
|
||||||
|
|
||||||
self.ca_verifier = LNChannelVerifier(network, self)
|
self.ca_verifier = LNChannelVerifier(network, self)
|
||||||
|
|
||||||
self.load_data()
|
def update_counts(self):
|
||||||
|
self.num_channels = DBSession.query(ChannelInfoInDB).count()
|
||||||
|
self.num_nodes = DBSession.query(NodeInfoInDB).count()
|
||||||
|
|
||||||
def load_data(self):
|
def add_recent_peer(self, peer : LNPeerAddr):
|
||||||
if os.path.exists(self.path):
|
addr = DBSession.query(AddressInDB).filter_by(node_id = peer.pubkey.hex()).one_or_none()
|
||||||
with open(self.path, "r", encoding='utf-8') as f:
|
if addr is None:
|
||||||
raw = f.read()
|
addr = AddressInDB(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now())
|
||||||
self.data = json.loads(raw)
|
else:
|
||||||
# channels
|
addr.last_connected_date = datetime.datetime.now()
|
||||||
channel_infos = self.get('channel_infos', {})
|
DBSession.add(addr)
|
||||||
for short_channel_id, channel_info_d in channel_infos.items():
|
DBSession.commit()
|
||||||
channel_info = ChannelInfo.from_json(channel_info_d)
|
|
||||||
short_channel_id = bfh(short_channel_id)
|
|
||||||
self.add_verified_channel_info(short_channel_id, channel_info)
|
|
||||||
# nodes
|
|
||||||
node_infos = self.get('node_infos', {})
|
|
||||||
for node_id, node_info_d in node_infos.items():
|
|
||||||
node_info = NodeInfo.from_json(node_info_d)
|
|
||||||
node_id = bfh(node_id)
|
|
||||||
self.nodes[node_id] = node_info
|
|
||||||
# recent peers
|
|
||||||
recent_peers = self.get('recent_peers', {})
|
|
||||||
for host, port, pubkey in recent_peers:
|
|
||||||
peer = LNPeerAddr(str(host), int(port), bfh(pubkey))
|
|
||||||
self._recent_peers.append(peer)
|
|
||||||
# last good address
|
|
||||||
last_good_addr = self.get('last_good_address', {})
|
|
||||||
for node_id, host_and_port in last_good_addr.items():
|
|
||||||
host, port = host_and_port
|
|
||||||
self._last_good_address[bfh(node_id)] = LNPeerAddr(str(host), int(port), bfh(node_id))
|
|
||||||
|
|
||||||
def save_data(self):
|
def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes):
|
||||||
with self.lock:
|
unshuffled = DBSession \
|
||||||
# channels
|
.query(NodeInfoInDB) \
|
||||||
channel_infos = {}
|
.filter(not_(NodeInfoInDB.node_id.in_(x.hex() for x in node_ids_bytes))) \
|
||||||
for short_channel_id, channel_info in self._id_to_channel_info.items():
|
.limit(200) \
|
||||||
channel_infos[bh2u(short_channel_id)] = channel_info
|
.all()
|
||||||
self.put('channel_infos', channel_infos)
|
return random.sample(unshuffled, len(unshuffled))
|
||||||
# nodes
|
|
||||||
node_infos = {}
|
|
||||||
for node_id, node_info in self.nodes.items():
|
|
||||||
node_infos[bh2u(node_id)] = node_info
|
|
||||||
self.put('node_infos', node_infos)
|
|
||||||
# recent peers
|
|
||||||
recent_peers = []
|
|
||||||
for peer in self._recent_peers:
|
|
||||||
recent_peers.append(
|
|
||||||
[str(peer.host), int(peer.port), bh2u(peer.pubkey)])
|
|
||||||
self.put('recent_peers', recent_peers)
|
|
||||||
# last good address
|
|
||||||
last_good_addr = {}
|
|
||||||
for node_id, peer in self._last_good_address.items():
|
|
||||||
last_good_addr[bh2u(node_id)] = [str(peer.host), int(peer.port)]
|
|
||||||
self.put('last_good_address', last_good_addr)
|
|
||||||
self.write()
|
|
||||||
|
|
||||||
def __len__(self):
|
def nodes_get(self, node_id):
|
||||||
# number of channels
|
return self.network.run_from_another_thread(self._nodes_get(node_id))
|
||||||
return len(self._id_to_channel_info)
|
|
||||||
|
|
||||||
def capacity(self):
|
async def _nodes_get(self, node_id):
|
||||||
# capacity of the network
|
return DBSession \
|
||||||
return sum(c.capacity_sat for c in self._id_to_channel_info.values() if c.capacity_sat is not None)
|
.query(NodeInfoInDB) \
|
||||||
|
.filter_by(node_id = node_id.hex()) \
|
||||||
|
.one_or_none()
|
||||||
|
|
||||||
def get_channel_info(self, channel_id: bytes) -> Optional[ChannelInfo]:
|
def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
|
||||||
return self._id_to_channel_info.get(channel_id, None)
|
adr_db = DBSession \
|
||||||
|
.query(AddressInDB) \
|
||||||
|
.filter_by(node_id = node_id.hex()) \
|
||||||
|
.order_by(AddressInDB.last_connected_date.desc()) \
|
||||||
|
.one_or_none()
|
||||||
|
if not adr_db:
|
||||||
|
return None
|
||||||
|
return LNPeerAddr(adr_db.host, adr_db.port, bytes.fromhex(adr_db.node_id))
|
||||||
|
|
||||||
|
def get_recent_peers(self):
|
||||||
|
return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in DBSession \
|
||||||
|
.query(AddressInDB) \
|
||||||
|
.select_from(NodeInfoInDB) \
|
||||||
|
.order_by(AddressInDB.last_connected_date.desc()) \
|
||||||
|
.limit(self.NUM_MAX_RECENT_PEERS)]
|
||||||
|
|
||||||
|
def get_channel_info(self, channel_id: bytes):
|
||||||
|
return self.chan_query_for_id(channel_id).one_or_none()
|
||||||
|
|
||||||
def get_channels_for_node(self, node_id):
|
def get_channels_for_node(self, node_id):
|
||||||
"""Returns the set of channels that have node_id as one of the endpoints."""
|
"""Returns the set of channels that have node_id as one of the endpoints."""
|
||||||
return self._channels_for_node[node_id]
|
condition = or_(
|
||||||
|
ChannelInfoInDB.node1_id == node_id.hex(),
|
||||||
|
ChannelInfoInDB.node2_id == node_id.hex())
|
||||||
|
rows = DBSession.query(ChannelInfoInDB).filter(condition).all()
|
||||||
|
return [bytes.fromhex(x.short_channel_id) for x in rows]
|
||||||
|
|
||||||
def add_verified_channel_info(self, short_channel_id: bytes, channel_info: ChannelInfo):
|
def add_verified_channel_info(self, short_id, capacity):
|
||||||
with self.lock:
|
# called from lnchannelverifier
|
||||||
self._id_to_channel_info[short_channel_id] = channel_info
|
channel_info = self.get_channel_info(short_id)
|
||||||
self._channels_for_node[channel_info.node_id_1].add(short_channel_id)
|
channel_info.trusted = True
|
||||||
self._channels_for_node[channel_info.node_id_2].add(short_channel_id)
|
channel_info.capacity = capacity
|
||||||
|
DBSession.commit()
|
||||||
|
|
||||||
|
@profiler
|
||||||
|
def on_channel_announcement(self, msg_payloads, trusted=False):
|
||||||
|
if type(msg_payloads) is dict:
|
||||||
|
msg_payloads = [msg_payloads]
|
||||||
|
for msg in msg_payloads:
|
||||||
|
short_channel_id = msg['short_channel_id']
|
||||||
|
if DBSession.query(ChannelInfoInDB).filter_by(short_channel_id = bh2u(short_channel_id)).count():
|
||||||
|
continue
|
||||||
|
if constants.net.rev_genesis_bytes() != msg['chain_hash']:
|
||||||
|
#self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash'])))
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
channel_info = ChannelInfoInDB.from_msg(msg)
|
||||||
|
except UnknownEvenFeatureBits:
|
||||||
|
continue
|
||||||
|
channel_info.trusted = trusted
|
||||||
|
DBSession.add(channel_info)
|
||||||
|
if not trusted: self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
|
||||||
|
DBSession.commit()
|
||||||
self.network.trigger_callback('ln_status')
|
self.network.trigger_callback('ln_status')
|
||||||
|
self.update_counts()
|
||||||
|
|
||||||
def get_recent_peers(self):
|
@profiler
|
||||||
with self.lock:
|
def on_channel_update(self, msg_payloads, trusted=False):
|
||||||
return list(self._recent_peers)
|
if type(msg_payloads) is dict:
|
||||||
|
msg_payloads = [msg_payloads]
|
||||||
|
short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads]
|
||||||
|
channel_infos_list = DBSession.query(ChannelInfoInDB).filter(ChannelInfoInDB.short_channel_id.in_(short_channel_ids)).all()
|
||||||
|
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
|
||||||
|
for msg_payload in msg_payloads:
|
||||||
|
short_channel_id = msg_payload['short_channel_id']
|
||||||
|
if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']:
|
||||||
|
continue
|
||||||
|
channel_info = channel_infos.get(short_channel_id)
|
||||||
|
channel_info.on_channel_update(msg_payload, trusted=trusted)
|
||||||
|
DBSession.commit()
|
||||||
|
|
||||||
def add_recent_peer(self, peer: LNPeerAddr):
|
@profiler
|
||||||
with self.lock:
|
def on_node_announcement(self, msg_payloads):
|
||||||
# list is ordered
|
if type(msg_payloads) is dict:
|
||||||
if peer in self._recent_peers:
|
msg_payloads = [msg_payloads]
|
||||||
self._recent_peers.remove(peer)
|
addresses = DBSession.query(AddressInDB).all()
|
||||||
self._recent_peers.insert(0, peer)
|
have_addr = {}
|
||||||
self._recent_peers = self._recent_peers[:self.NUM_MAX_RECENT_PEERS]
|
for addr in addresses:
|
||||||
self._last_good_address[peer.pubkey] = peer
|
have_addr[(addr.node_id, addr.host, addr.port)] = addr
|
||||||
|
|
||||||
def get_last_good_address(self, node_id: bytes) -> Optional[LNPeerAddr]:
|
nodes = DBSession.query(NodeInfoInDB).all()
|
||||||
return self._last_good_address.get(node_id, None)
|
timestamps = {}
|
||||||
|
for node in nodes:
|
||||||
def on_channel_announcement(self, msg_payload, trusted=False):
|
no_millisecs = node.timestamp[:len("0000-00-00 00:00:00")]
|
||||||
short_channel_id = msg_payload['short_channel_id']
|
timestamps[bfh(node.node_id)] = datetime.datetime.strptime(no_millisecs, "%Y-%m-%d %H:%M:%S")
|
||||||
if short_channel_id in self._id_to_channel_info:
|
old_addr = None
|
||||||
return
|
for msg_payload in msg_payloads:
|
||||||
if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']:
|
pubkey = msg_payload['node_id']
|
||||||
#self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash'])))
|
signature = msg_payload['signature']
|
||||||
return
|
h = sha256d(msg_payload['raw'][66:])
|
||||||
try:
|
if not ecc.verify_signature(pubkey, signature, h):
|
||||||
channel_info = ChannelInfo(msg_payload)
|
continue
|
||||||
except UnknownEvenFeatureBits:
|
try:
|
||||||
return
|
new_node_info, addresses = NodeInfoInDB.from_msg(msg_payload)
|
||||||
if trusted:
|
except UnknownEvenFeatureBits:
|
||||||
self.add_verified_channel_info(short_channel_id, channel_info)
|
continue
|
||||||
else:
|
if timestamps.get(pubkey) and timestamps[pubkey] >= new_node_info.timestamp:
|
||||||
self.ca_verifier.add_new_channel_info(channel_info)
|
continue # ignore
|
||||||
|
DBSession.add(new_node_info)
|
||||||
def on_channel_update(self, msg_payload, trusted=False):
|
for new_addr in addresses:
|
||||||
short_channel_id = msg_payload['short_channel_id']
|
key = (new_addr.node_id, new_addr.host, new_addr.port)
|
||||||
if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']:
|
old_addr = have_addr.get(key)
|
||||||
return
|
if old_addr:
|
||||||
# try finding channel in pending db
|
# since old_addr is embedded in have_addr,
|
||||||
channel_info = self.ca_verifier.get_pending_channel_info(short_channel_id)
|
# it will still live when commmit is called
|
||||||
if channel_info is None:
|
old_addr.last_connected_date = new_addr.last_connected_date
|
||||||
# try finding channel in verified db
|
del new_addr
|
||||||
channel_info = self._id_to_channel_info.get(short_channel_id, None)
|
else:
|
||||||
if channel_info is None:
|
DBSession.add(new_addr)
|
||||||
self.print_error("could not find", short_channel_id)
|
have_addr[key] = new_addr
|
||||||
raise NotFoundChanAnnouncementForUpdate()
|
# TODO if this message is for a new node, and if we have no associated
|
||||||
channel_info.on_channel_update(msg_payload, trusted=trusted)
|
# channels for this node, we should ignore the message and return here,
|
||||||
|
# to mitigate DOS. but race condition: the channels we have for this
|
||||||
def on_node_announcement(self, msg_payload):
|
# node, might be under verification in self.ca_verifier, what then?
|
||||||
pubkey = msg_payload['node_id']
|
del nodes, addresses
|
||||||
signature = msg_payload['signature']
|
if old_addr:
|
||||||
h = sha256d(msg_payload['raw'][66:])
|
del old_addr
|
||||||
if not ecc.verify_signature(pubkey, signature, h):
|
DBSession.commit()
|
||||||
return
|
self.network.trigger_callback('ln_status')
|
||||||
old_node_info = self.nodes.get(pubkey, None)
|
self.update_counts()
|
||||||
try:
|
|
||||||
new_node_info = NodeInfo(msg_payload)
|
|
||||||
except UnknownEvenFeatureBits:
|
|
||||||
return
|
|
||||||
# TODO if this message is for a new node, and if we have no associated
|
|
||||||
# channels for this node, we should ignore the message and return here,
|
|
||||||
# to mitigate DOS. but race condition: the channels we have for this
|
|
||||||
# node, might be under verification in self.ca_verifier, what then?
|
|
||||||
if old_node_info and old_node_info.timestamp >= new_node_info.timestamp:
|
|
||||||
return # ignore
|
|
||||||
self.nodes[pubkey] = new_node_info
|
|
||||||
|
|
||||||
def get_routing_policy_for_channel(self, start_node_id: bytes,
|
def get_routing_policy_for_channel(self, start_node_id: bytes,
|
||||||
short_channel_id: bytes) -> Optional[ChannelInfoDirectedPolicy]:
|
short_channel_id: bytes) -> Optional[bytes]:
|
||||||
if not start_node_id or not short_channel_id: return None
|
if not start_node_id or not short_channel_id: return None
|
||||||
channel_info = self.get_channel_info(short_channel_id)
|
channel_info = self.get_channel_info(short_channel_id)
|
||||||
if channel_info is not None:
|
if channel_info is not None:
|
||||||
return channel_info.get_policy_for_node(start_node_id)
|
return channel_info.get_policy_for_node(start_node_id)
|
||||||
return self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
|
msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
|
||||||
|
if not msg: return None
|
||||||
|
return Policy.from_msg(msg, None, short_channel_id) # won't actually be written to DB
|
||||||
|
|
||||||
def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
|
def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
|
||||||
if not verify_sig_for_channel_update(msg_payload, start_node_id):
|
if not verify_sig_for_channel_update(msg_payload, start_node_id):
|
||||||
return # ignore
|
return # ignore
|
||||||
short_channel_id = msg_payload['short_channel_id']
|
short_channel_id = msg_payload['short_channel_id']
|
||||||
policy = ChannelInfoDirectedPolicy(msg_payload)
|
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
|
||||||
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = policy
|
|
||||||
|
|
||||||
def remove_channel(self, short_channel_id):
|
def remove_channel(self, short_channel_id):
|
||||||
try:
|
self.chan_query_for_id(short_channel_id).delete('evaluate')
|
||||||
channel_info = self._id_to_channel_info[short_channel_id]
|
DBSession.commit()
|
||||||
except KeyError:
|
|
||||||
self.print_error(f'remove_channel: cannot find channel {bh2u(short_channel_id)}')
|
def chan_query_for_id(self, short_channel_id) -> Query:
|
||||||
return
|
return DBSession.query(ChannelInfoInDB).filter_by(short_channel_id = short_channel_id.hex())
|
||||||
self._id_to_channel_info.pop(short_channel_id, None)
|
|
||||||
for node in (channel_info.node_id_1, channel_info.node_id_2):
|
|
||||||
try:
|
|
||||||
self._channels_for_node[node].remove(short_channel_id)
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def print_graph(self, full_ids=False):
|
def print_graph(self, full_ids=False):
|
||||||
# used for debugging.
|
# used for debugging.
|
||||||
# FIXME there is a race here - iterables could change size from another thread
|
# FIXME there is a race here - iterables could change size from another thread
|
||||||
def other_node_id(node_id, channel_id):
|
def other_node_id(node_id, channel_id):
|
||||||
channel_info = self._id_to_channel_info[channel_id]
|
channel_info = self.get_channel_info(channel_id)
|
||||||
if node_id == channel_info.node_id_1:
|
if node_id == channel_info.node1_id:
|
||||||
other = channel_info.node_id_2
|
other = channel_info.node2_id
|
||||||
else:
|
else:
|
||||||
other = channel_info.node_id_1
|
other = channel_info.node1_id
|
||||||
return other if full_ids else other[-4:]
|
return other if full_ids else other[-4:]
|
||||||
|
|
||||||
self.print_msg('node: {(channel, other_node), ...}')
|
self.print_msg('nodes')
|
||||||
for node_id, short_channel_ids in list(self._channels_for_node.items()):
|
for node in DBSession.query(NodeInfoInDB).all():
|
||||||
short_channel_ids = {(bh2u(cid), bh2u(other_node_id(node_id, cid)))
|
self.print_msg(node)
|
||||||
for cid in short_channel_ids}
|
|
||||||
node_id = bh2u(node_id) if full_ids else bh2u(node_id[-4:])
|
|
||||||
self.print_msg('{}: {}'.format(node_id, short_channel_ids))
|
|
||||||
|
|
||||||
self.print_msg('channel: node1, node2, direction')
|
self.print_msg('channels')
|
||||||
for short_channel_id, channel_info in list(self._id_to_channel_info.items()):
|
for channel_info in DBSession.query(ChannelInfoInDB).all():
|
||||||
node1 = channel_info.node_id_1
|
node1 = channel_info.node1_id
|
||||||
node2 = channel_info.node_id_2
|
node2 = channel_info.node2_id
|
||||||
direction1 = channel_info.get_policy_for_node(node1) is not None
|
direction1 = channel_info.get_policy_for_node(node1) is not None
|
||||||
direction2 = channel_info.get_policy_for_node(node2) is not None
|
direction2 = channel_info.get_policy_for_node(node2) is not None
|
||||||
if direction1 and direction2:
|
if direction1 and direction2:
|
||||||
@@ -514,8 +506,10 @@ class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes),
|
|||||||
+ (amount_msat * self.fee_proportional_millionths // 1_000_000)
|
+ (amount_msat * self.fee_proportional_millionths // 1_000_000)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_channel_policy(cls, channel_policy: ChannelInfoDirectedPolicy,
|
def from_channel_policy(cls, channel_policy: 'Policy',
|
||||||
short_channel_id: bytes, end_node: bytes) -> 'RouteEdge':
|
short_channel_id: bytes, end_node: bytes) -> 'RouteEdge':
|
||||||
|
assert type(short_channel_id) is bytes
|
||||||
|
assert type(end_node) is bytes
|
||||||
return RouteEdge(end_node,
|
return RouteEdge(end_node,
|
||||||
short_channel_id,
|
short_channel_id,
|
||||||
channel_policy.fee_base_msat,
|
channel_policy.fee_base_msat,
|
||||||
@@ -582,7 +576,7 @@ class LNPathFinder(PrintError):
|
|||||||
|
|
||||||
channel_policy = channel_info.get_policy_for_node(start_node)
|
channel_policy = channel_info.get_policy_for_node(start_node)
|
||||||
if channel_policy is None: return float('inf'), 0
|
if channel_policy is None: return float('inf'), 0
|
||||||
if channel_policy.disabled: return float('inf'), 0
|
if channel_policy.is_disabled(): return float('inf'), 0
|
||||||
route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node)
|
route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node)
|
||||||
if payment_amt_msat < channel_policy.htlc_minimum_msat:
|
if payment_amt_msat < channel_policy.htlc_minimum_msat:
|
||||||
return float('inf'), 0 # payment amount too little
|
return float('inf'), 0 # payment amount too little
|
||||||
@@ -611,6 +605,8 @@ class LNPathFinder(PrintError):
|
|||||||
To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1];
|
To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1];
|
||||||
i.e. an element reads as, "to get to node_id, travel through short_channel_id"
|
i.e. an element reads as, "to get to node_id, travel through short_channel_id"
|
||||||
"""
|
"""
|
||||||
|
assert type(nodeA) is bytes
|
||||||
|
assert type(nodeB) is bytes
|
||||||
assert type(invoice_amount_msat) is int
|
assert type(invoice_amount_msat) is int
|
||||||
if my_channels is None: my_channels = []
|
if my_channels is None: my_channels = []
|
||||||
my_channels = {chan.short_channel_id: chan for chan in my_channels}
|
my_channels = {chan.short_channel_id: chan for chan in my_channels}
|
||||||
@@ -657,9 +653,10 @@ class LNPathFinder(PrintError):
|
|||||||
# so there are duplicates in the queue, that we discard now:
|
# so there are duplicates in the queue, that we discard now:
|
||||||
continue
|
continue
|
||||||
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode):
|
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode):
|
||||||
|
assert type(edge_channel_id) is bytes
|
||||||
if edge_channel_id in self.blacklist: continue
|
if edge_channel_id in self.blacklist: continue
|
||||||
channel_info = self.channel_db.get_channel_info(edge_channel_id)
|
channel_info = self.channel_db.get_channel_info(edge_channel_id)
|
||||||
edge_startnode = channel_info.node_id_2 if channel_info.node_id_1 == edge_endnode else channel_info.node_id_1
|
edge_startnode = bfh(channel_info.node2_id) if bfh(channel_info.node1_id) == edge_endnode else bfh(channel_info.node1_id)
|
||||||
inspect_edge()
|
inspect_edge()
|
||||||
else:
|
else:
|
||||||
return None # no path found
|
return None # no path found
|
||||||
@@ -682,7 +679,7 @@ class LNPathFinder(PrintError):
|
|||||||
for node_id, short_channel_id in path:
|
for node_id, short_channel_id in path:
|
||||||
channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id)
|
channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id)
|
||||||
if channel_policy is None:
|
if channel_policy is None:
|
||||||
raise Exception(f'cannot find channel policy for short_channel_id: {bh2u(short_channel_id)}')
|
raise NoChannelPolicy(short_channel_id)
|
||||||
route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id))
|
route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id))
|
||||||
prev_node_id = node_id
|
prev_node_id = node_id
|
||||||
return route
|
return route
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ from .verifier import verify_tx_is_in_block, MerkleVerificationFailure
|
|||||||
from .transaction import Transaction
|
from .transaction import Transaction
|
||||||
from .interface import GracefulDisconnect
|
from .interface import GracefulDisconnect
|
||||||
from .crypto import sha256d
|
from .crypto import sha256d
|
||||||
from .lnmsg import encode_msg
|
from .lnmsg import decode_msg, encode_msg
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .network import Network
|
from .network import Network
|
||||||
@@ -56,7 +56,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
|
|||||||
NetworkJobOnDefaultServer.__init__(self, network)
|
NetworkJobOnDefaultServer.__init__(self, network)
|
||||||
self.channel_db = channel_db
|
self.channel_db = channel_db
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
self.unverified_channel_info = {} # short_channel_id -> channel_info
|
self.unverified_channel_info = {} # short_channel_id -> msg_payload
|
||||||
# channel announcements that seem to be invalid:
|
# channel announcements that seem to be invalid:
|
||||||
self.blacklist = set() # short_channel_id
|
self.blacklist = set() # short_channel_id
|
||||||
|
|
||||||
@@ -65,19 +65,16 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
|
|||||||
self.started_verifying_channel = set() # short_channel_id
|
self.started_verifying_channel = set() # short_channel_id
|
||||||
|
|
||||||
# TODO make async; and rm self.lock completely
|
# TODO make async; and rm self.lock completely
|
||||||
def add_new_channel_info(self, channel_info):
|
def add_new_channel_info(self, short_channel_id_hex, msg_payload):
|
||||||
short_channel_id = channel_info.channel_id
|
short_channel_id = bfh(short_channel_id_hex)
|
||||||
if short_channel_id in self.unverified_channel_info:
|
if short_channel_id in self.unverified_channel_info:
|
||||||
return
|
return
|
||||||
if short_channel_id in self.blacklist:
|
if short_channel_id in self.blacklist:
|
||||||
return
|
return
|
||||||
if not verify_sigs_for_channel_announcement(channel_info.msg_payload):
|
if not verify_sigs_for_channel_announcement(msg_payload):
|
||||||
return
|
return
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.unverified_channel_info[short_channel_id] = channel_info
|
self.unverified_channel_info[short_channel_id] = msg_payload
|
||||||
|
|
||||||
def get_pending_channel_info(self, short_channel_id):
|
|
||||||
return self.unverified_channel_info.get(short_channel_id, None)
|
|
||||||
|
|
||||||
async def _start_tasks(self):
|
async def _start_tasks(self):
|
||||||
async with self.group as group:
|
async with self.group as group:
|
||||||
@@ -151,8 +148,9 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
|
|||||||
self.print_error(f"received tx does not match expected txid ({tx_hash} != {tx.txid()})")
|
self.print_error(f"received tx does not match expected txid ({tx_hash} != {tx.txid()})")
|
||||||
return
|
return
|
||||||
# check funding output
|
# check funding output
|
||||||
channel_info = self.unverified_channel_info[short_channel_id]
|
msg_payload = self.unverified_channel_info[short_channel_id]
|
||||||
chan_ann = channel_info.msg_payload
|
msg_type, chan_ann = decode_msg(msg_payload)
|
||||||
|
assert msg_type == 'channel_announcement'
|
||||||
redeem_script = funding_output_script_from_keys(chan_ann['bitcoin_key_1'], chan_ann['bitcoin_key_2'])
|
redeem_script = funding_output_script_from_keys(chan_ann['bitcoin_key_1'], chan_ann['bitcoin_key_2'])
|
||||||
expected_address = bitcoin.redeem_script_to_address('p2wsh', redeem_script)
|
expected_address = bitcoin.redeem_script_to_address('p2wsh', redeem_script)
|
||||||
output_idx = invert_short_channel_id(short_channel_id)[2]
|
output_idx = invert_short_channel_id(short_channel_id)[2]
|
||||||
@@ -167,8 +165,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
|
|||||||
self._remove_channel_from_unverified_db(short_channel_id)
|
self._remove_channel_from_unverified_db(short_channel_id)
|
||||||
return
|
return
|
||||||
# put channel into channel DB
|
# put channel into channel DB
|
||||||
channel_info.set_capacity(actual_output.value)
|
self.channel_db.add_verified_channel_info(short_channel_id, actual_output.value)
|
||||||
self.channel_db.add_verified_channel_info(short_channel_id, channel_info)
|
|
||||||
self._remove_channel_from_unverified_db(short_channel_id)
|
self._remove_channel_from_unverified_db(short_channel_id)
|
||||||
|
|
||||||
def _remove_channel_from_unverified_db(self, short_channel_id: bytes):
|
def _remove_channel_from_unverified_db(self, short_channel_id: bytes):
|
||||||
@@ -183,8 +180,9 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
|
|||||||
self.unverified_channel_info.pop(short_channel_id, None)
|
self.unverified_channel_info.pop(short_channel_id, None)
|
||||||
|
|
||||||
|
|
||||||
def verify_sigs_for_channel_announcement(chan_ann: dict) -> bool:
|
def verify_sigs_for_channel_announcement(msg_bytes: bytes) -> bool:
|
||||||
msg_bytes = encode_msg('channel_announcement', **chan_ann)
|
msg_type, chan_ann = decode_msg(msg_bytes)
|
||||||
|
assert msg_type == 'channel_announcement'
|
||||||
pre_hash = msg_bytes[2+256:]
|
pre_hash = msg_bytes[2+256:]
|
||||||
h = sha256d(pre_hash)
|
h = sha256d(pre_hash)
|
||||||
pubkeys = [chan_ann['node_id_1'], chan_ann['node_id_2'], chan_ann['bitcoin_key_1'], chan_ann['bitcoin_key_2']]
|
pubkeys = [chan_ann['node_id_1'], chan_ann['node_id_2'], chan_ann['bitcoin_key_1'], chan_ann['bitcoin_key_2']]
|
||||||
|
|||||||
@@ -56,7 +56,8 @@ GRAPH_DOWNLOAD_SECONDS = 600
|
|||||||
|
|
||||||
FALLBACK_NODE_LIST_TESTNET = (
|
FALLBACK_NODE_LIST_TESTNET = (
|
||||||
LNPeerAddr('ecdsa.net', 9735, bfh('038370f0e7a03eded3e1d41dc081084a87f0afa1c5b22090b4f3abb391eb15d8ff')),
|
LNPeerAddr('ecdsa.net', 9735, bfh('038370f0e7a03eded3e1d41dc081084a87f0afa1c5b22090b4f3abb391eb15d8ff')),
|
||||||
LNPeerAddr('180.181.208.42', 9735, bfh('038863cf8ab91046230f561cd5b386cbff8309fa02e3f0c3ed161a3aeb64a643b9')),
|
LNPeerAddr('148.251.87.112', 9735, bfh('021a8bd8d8f1f2e208992a2eb755cdc74d44e66b6a0c924d3a3cce949123b9ce40')), # janus test server
|
||||||
|
LNPeerAddr('122.199.61.90', 9735, bfh('038863cf8ab91046230f561cd5b386cbff8309fa02e3f0c3ed161a3aeb64a643b9')), # popular node https://1ml.com/testnet/node/038863cf8ab91046230f561cd5b386cbff8309fa02e3f0c3ed161a3aeb64a643b9
|
||||||
)
|
)
|
||||||
FALLBACK_NODE_LIST_MAINNET = (
|
FALLBACK_NODE_LIST_MAINNET = (
|
||||||
LNPeerAddr('104.198.32.198', 9735, bfh('02f6725f9c1c40333b67faea92fd211c183050f28df32cac3f9d69685fe9665432')), # Blockstream
|
LNPeerAddr('104.198.32.198', 9735, bfh('02f6725f9c1c40333b67faea92fd211c183050f28df32cac3f9d69685fe9665432')), # Blockstream
|
||||||
@@ -420,26 +421,33 @@ class LNWorker(PrintError):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]:
|
def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]:
|
||||||
|
assert len(addr_list) >= 1
|
||||||
# choose first one that is an IP
|
# choose first one that is an IP
|
||||||
for host, port in addr_list:
|
for addr_in_db in addr_list:
|
||||||
|
host = addr_in_db.host
|
||||||
|
port = addr_in_db.port
|
||||||
if is_ip_address(host):
|
if is_ip_address(host):
|
||||||
return host, port
|
return host, port
|
||||||
# otherwise choose one at random
|
# otherwise choose one at random
|
||||||
# TODO maybe filter out onion if not on tor?
|
# TODO maybe filter out onion if not on tor?
|
||||||
return random.choice(addr_list)
|
choice = random.choice(addr_list)
|
||||||
|
return choice.host, choice.port
|
||||||
|
|
||||||
def open_channel(self, connect_contents, local_amt_sat, push_amt_sat, password=None, timeout=5):
|
def open_channel(self, connect_contents, local_amt_sat, push_amt_sat, password=None, timeout=5):
|
||||||
node_id, rest = extract_nodeid(connect_contents)
|
node_id, rest = extract_nodeid(connect_contents)
|
||||||
peer = self.peers.get(node_id)
|
peer = self.peers.get(node_id)
|
||||||
if not peer:
|
if not peer:
|
||||||
all_nodes = self.network.channel_db.nodes
|
nodes_get = self.network.channel_db.nodes_get
|
||||||
node_info = all_nodes.get(node_id, None)
|
node_info = nodes_get(node_id)
|
||||||
if rest is not None:
|
if rest is not None:
|
||||||
host, port = split_host_port(rest)
|
host, port = split_host_port(rest)
|
||||||
elif node_info and len(node_info.addresses) > 0:
|
|
||||||
host, port = self.choose_preferred_address(node_info.addresses)
|
|
||||||
else:
|
else:
|
||||||
raise ConnStringFormatError(_('Unknown node:') + ' ' + bh2u(node_id))
|
if not node_info:
|
||||||
|
raise ConnStringFormatError(_('Unknown node:') + ' ' + bh2u(node_id))
|
||||||
|
addrs = node_info.get_addresses()
|
||||||
|
if len(addrs) == 0:
|
||||||
|
raise ConnStringFormatError(_('Don\'t know any addresses for node:') + ' ' + bh2u(node_id))
|
||||||
|
host, port = self.choose_preferred_address(addrs)
|
||||||
try:
|
try:
|
||||||
socket.getaddrinfo(host, int(port))
|
socket.getaddrinfo(host, int(port))
|
||||||
except socket.gaierror:
|
except socket.gaierror:
|
||||||
@@ -457,7 +465,7 @@ class LNWorker(PrintError):
|
|||||||
This is not merged with _pay so that we can run the test with
|
This is not merged with _pay so that we can run the test with
|
||||||
one thread only.
|
one thread only.
|
||||||
"""
|
"""
|
||||||
addr, peer, coro = self._pay(invoice, amount_sat)
|
addr, peer, coro = self.network.run_from_another_thread(self._pay(invoice, amount_sat))
|
||||||
fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
|
fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
|
||||||
return addr, peer, fut
|
return addr, peer, fut
|
||||||
|
|
||||||
@@ -467,9 +475,9 @@ class LNWorker(PrintError):
|
|||||||
if chan.short_channel_id == short_channel_id:
|
if chan.short_channel_id == short_channel_id:
|
||||||
return chan
|
return chan
|
||||||
|
|
||||||
def _pay(self, invoice, amount_sat=None):
|
async def _pay(self, invoice, amount_sat=None, same_thread=False):
|
||||||
addr = self._check_invoice(invoice, amount_sat)
|
addr = self._check_invoice(invoice, amount_sat)
|
||||||
route = self._create_route_from_invoice(decoded_invoice=addr)
|
route = await self._create_route_from_invoice(decoded_invoice=addr)
|
||||||
peer = self.peers[route[0].node_id]
|
peer = self.peers[route[0].node_id]
|
||||||
if not self.get_channel_by_short_id(route[0].short_channel_id):
|
if not self.get_channel_by_short_id(route[0].short_channel_id):
|
||||||
assert False, 'Found route with short channel ID we don\'t have: ' + repr(route[0].short_channel_id)
|
assert False, 'Found route with short channel ID we don\'t have: ' + repr(route[0].short_channel_id)
|
||||||
@@ -498,7 +506,7 @@ class LNWorker(PrintError):
|
|||||||
f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
|
f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
|
||||||
return addr
|
return addr
|
||||||
|
|
||||||
def _create_route_from_invoice(self, decoded_invoice) -> List[RouteEdge]:
|
async def _create_route_from_invoice(self, decoded_invoice) -> List[RouteEdge]:
|
||||||
amount_msat = int(decoded_invoice.amount * COIN * 1000)
|
amount_msat = int(decoded_invoice.amount * COIN * 1000)
|
||||||
invoice_pubkey = decoded_invoice.pubkey.serialize()
|
invoice_pubkey = decoded_invoice.pubkey.serialize()
|
||||||
# use 'r' field from invoice
|
# use 'r' field from invoice
|
||||||
@@ -699,20 +707,14 @@ class LNWorker(PrintError):
|
|||||||
if peer in self._last_tried_peer: continue
|
if peer in self._last_tried_peer: continue
|
||||||
return [peer]
|
return [peer]
|
||||||
# try random peer from graph
|
# try random peer from graph
|
||||||
all_nodes = self.channel_db.nodes
|
unconnected_nodes = self.channel_db.get_200_randomly_sorted_nodes_not_in(self.peers.keys())
|
||||||
if all_nodes:
|
if unconnected_nodes:
|
||||||
#self.print_error('trying to get ln peers from channel db')
|
for node in unconnected_nodes:
|
||||||
node_ids = list(all_nodes)
|
addrs = node.get_addresses()
|
||||||
max_tries = min(200, len(all_nodes))
|
if not addrs:
|
||||||
for i in range(max_tries):
|
continue
|
||||||
node_id = random.choice(node_ids)
|
host, port = self.choose_preferred_address(addrs)
|
||||||
node = all_nodes.get(node_id)
|
peer = LNPeerAddr(host, port, bytes.fromhex(node.node_id))
|
||||||
if node is None: continue
|
|
||||||
addresses = node.addresses
|
|
||||||
if not addresses: continue
|
|
||||||
host, port = self.choose_preferred_address(addresses)
|
|
||||||
peer = LNPeerAddr(host, port, node_id)
|
|
||||||
if peer.pubkey in self.peers: continue
|
|
||||||
if peer in self._last_tried_peer: continue
|
if peer in self._last_tried_peer: continue
|
||||||
self.print_error('taking random ln peer from our channel db')
|
self.print_error('taking random ln peer from our channel db')
|
||||||
return [peer]
|
return [peer]
|
||||||
@@ -772,11 +774,12 @@ class LNWorker(PrintError):
|
|||||||
await self.add_peer(peer.host, peer.port, peer.pubkey)
|
await self.add_peer(peer.host, peer.port, peer.pubkey)
|
||||||
return
|
return
|
||||||
# try random address for node_id
|
# try random address for node_id
|
||||||
node_info = self.channel_db.nodes.get(chan.node_id, None)
|
node_info = await self.channel_db._nodes_get(chan.node_id)
|
||||||
if not node_info: return
|
if not node_info: return
|
||||||
addresses = node_info.addresses
|
addresses = node_info.get_addresses()
|
||||||
if not addresses: return
|
if not addresses: return
|
||||||
host, port = random.choice(addresses)
|
adr_obj = random.choice(addresses)
|
||||||
|
host, port = adr_obj.host, adr_obj.port
|
||||||
peer = LNPeerAddr(host, port, chan.node_id)
|
peer = LNPeerAddr(host, port, chan.node_id)
|
||||||
last_tried = self._last_tried_peer.get(peer, 0)
|
last_tried = self._last_tried_peer.get(peer, 0)
|
||||||
if last_tried + PEER_RETRY_INTERVAL_FOR_CHANNELS < now:
|
if last_tried + PEER_RETRY_INTERVAL_FOR_CHANNELS < now:
|
||||||
|
|||||||
@@ -1181,7 +1181,6 @@ class Network(Logger):
|
|||||||
def stop(self):
|
def stop(self):
|
||||||
assert self._loop_thread != threading.current_thread(), 'must not be called from network thread'
|
assert self._loop_thread != threading.current_thread(), 'must not be called from network thread'
|
||||||
fut = asyncio.run_coroutine_threadsafe(self._stop(full_shutdown=True), self.asyncio_loop)
|
fut = asyncio.run_coroutine_threadsafe(self._stop(full_shutdown=True), self.asyncio_loop)
|
||||||
self.channel_db.save_data()
|
|
||||||
try:
|
try:
|
||||||
fut.result(timeout=2)
|
fut.result(timeout=2)
|
||||||
except (asyncio.TimeoutError, asyncio.CancelledError): pass
|
except (asyncio.TimeoutError, asyncio.CancelledError): pass
|
||||||
|
|||||||
@@ -226,7 +226,7 @@ class TestPeer(unittest.TestCase):
|
|||||||
fut = self.prepare_ln_message_future(w2)
|
fut = self.prepare_ln_message_future(w2)
|
||||||
|
|
||||||
async def pay():
|
async def pay():
|
||||||
addr, peer, coro = LNWorker._pay(w1, pay_req)
|
addr, peer, coro = await LNWorker._pay(w1, pay_req, same_thread=True)
|
||||||
await coro
|
await coro
|
||||||
print("HTLC ADDED")
|
print("HTLC ADDED")
|
||||||
self.assertEqual(await fut, 'Payment received')
|
self.assertEqual(await fut, 'Payment received')
|
||||||
@@ -240,14 +240,14 @@ class TestPeer(unittest.TestCase):
|
|||||||
pay_req = self.prepare_invoice(w2)
|
pay_req = self.prepare_invoice(w2)
|
||||||
|
|
||||||
addr = w1._check_invoice(pay_req)
|
addr = w1._check_invoice(pay_req)
|
||||||
route = w1._create_route_from_invoice(decoded_invoice=addr)
|
route = run(w1._create_route_from_invoice(decoded_invoice=addr))
|
||||||
|
|
||||||
run(w1.force_close_channel(self.alice_channel.channel_id))
|
run(w1.force_close_channel(self.alice_channel.channel_id))
|
||||||
# check if a tx (commitment transaction) was broadcasted:
|
# check if a tx (commitment transaction) was broadcasted:
|
||||||
assert q1.qsize() == 1
|
assert q1.qsize() == 1
|
||||||
|
|
||||||
with self.assertRaises(PaymentFailure) as e:
|
with self.assertRaises(PaymentFailure) as e:
|
||||||
w1._create_route_from_invoice(decoded_invoice=addr)
|
run(w1._create_route_from_invoice(decoded_invoice=addr))
|
||||||
self.assertEqual(str(e.exception), 'No path found')
|
self.assertEqual(str(e.exception), 'No path found')
|
||||||
|
|
||||||
peer = w1.peers[route[0].node_id]
|
peer = w1.peers[route[0].node_id]
|
||||||
@@ -257,4 +257,4 @@ class TestPeer(unittest.TestCase):
|
|||||||
run(asyncio.gather(w1._pay_to_route(route, addr, pay_req), p1._main_loop(), p2._main_loop()))
|
run(asyncio.gather(w1._pay_to_route(route, addr, pay_req), p1._main_loop(), p2._main_loop()))
|
||||||
|
|
||||||
def run(coro):
|
def run(coro):
|
||||||
asyncio.get_event_loop().run_until_complete(coro)
|
return asyncio.get_event_loop().run_until_complete(coro)
|
||||||
|
|||||||
@@ -45,15 +45,17 @@ class Test_LNRouter(TestCaseForTestnet):
|
|||||||
asyncio_loop = asyncio.get_event_loop()
|
asyncio_loop = asyncio.get_event_loop()
|
||||||
trigger_callback = lambda *args: None
|
trigger_callback = lambda *args: None
|
||||||
register_callback = lambda *args: None
|
register_callback = lambda *args: None
|
||||||
async def add_job(self, *args): return None
|
interface = None
|
||||||
fake_network.channel_db = lnrouter.ChannelDB(fake_network())
|
fake_network.channel_db = lnrouter.ChannelDB(fake_network())
|
||||||
cdb = fake_network.channel_db
|
cdb = fake_network.channel_db
|
||||||
path_finder = lnrouter.LNPathFinder(cdb)
|
path_finder = lnrouter.LNPathFinder(cdb)
|
||||||
|
self.assertEqual(cdb.num_channels, 0)
|
||||||
cdb.on_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02cccccccccccccccccccccccccccccccc',
|
cdb.on_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02cccccccccccccccccccccccccccccccc',
|
||||||
'bitcoin_key_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'bitcoin_key_2': b'\x02cccccccccccccccccccccccccccccccc',
|
'bitcoin_key_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'bitcoin_key_2': b'\x02cccccccccccccccccccccccccccccccc',
|
||||||
'short_channel_id': bfh('0000000000000001'),
|
'short_channel_id': bfh('0000000000000001'),
|
||||||
'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
|
'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
|
||||||
'len': b'\x00\x00', 'features': b''}, trusted=True)
|
'len': b'\x00\x00', 'features': b''}, trusted=True)
|
||||||
|
self.assertEqual(cdb.num_channels, 1)
|
||||||
cdb.on_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
|
cdb.on_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
|
||||||
'bitcoin_key_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'bitcoin_key_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
|
'bitcoin_key_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'bitcoin_key_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
|
||||||
'short_channel_id': bfh('0000000000000002'),
|
'short_channel_id': bfh('0000000000000002'),
|
||||||
@@ -92,12 +94,16 @@ class Test_LNRouter(TestCaseForTestnet):
|
|||||||
cdb.on_channel_update({'short_channel_id': bfh('0000000000000005'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': o(10), 'htlc_minimum_msat': o(250), 'fee_base_msat': o(100), 'fee_proportional_millionths': o(999), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'timestamp': b'\x00\x00\x00\x00'}, trusted=True)
|
cdb.on_channel_update({'short_channel_id': bfh('0000000000000005'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': o(10), 'htlc_minimum_msat': o(250), 'fee_base_msat': o(100), 'fee_proportional_millionths': o(999), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'timestamp': b'\x00\x00\x00\x00'}, trusted=True)
|
||||||
cdb.on_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': o(10), 'htlc_minimum_msat': o(250), 'fee_base_msat': o(100), 'fee_proportional_millionths': o(99999999), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'timestamp': b'\x00\x00\x00\x00'}, trusted=True)
|
cdb.on_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': o(10), 'htlc_minimum_msat': o(250), 'fee_base_msat': o(100), 'fee_proportional_millionths': o(99999999), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'timestamp': b'\x00\x00\x00\x00'}, trusted=True)
|
||||||
cdb.on_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x01', 'cltv_expiry_delta': o(10), 'htlc_minimum_msat': o(250), 'fee_base_msat': o(100), 'fee_proportional_millionths': o(150), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'timestamp': b'\x00\x00\x00\x00'}, trusted=True)
|
cdb.on_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x01', 'cltv_expiry_delta': o(10), 'htlc_minimum_msat': o(250), 'fee_base_msat': o(100), 'fee_proportional_millionths': o(150), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'timestamp': b'\x00\x00\x00\x00'}, trusted=True)
|
||||||
self.assertNotEqual(None, path_finder.find_path_for_payment(b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 100000))
|
path = path_finder.find_path_for_payment(b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 100000)
|
||||||
self.assertEqual([(b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', b'\x00\x00\x00\x00\x00\x00\x00\x03'),
|
self.assertEqual([(b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', b'\x00\x00\x00\x00\x00\x00\x00\x03'),
|
||||||
(b'\x02cccccccccccccccccccccccccccccccc', b'\x00\x00\x00\x00\x00\x00\x00\x01'),
|
(b'\x02cccccccccccccccccccccccccccccccc', b'\x00\x00\x00\x00\x00\x00\x00\x01'),
|
||||||
(b'\x02dddddddddddddddddddddddddddddddd', b'\x00\x00\x00\x00\x00\x00\x00\x04'),
|
(b'\x02dddddddddddddddddddddddddddddddd', b'\x00\x00\x00\x00\x00\x00\x00\x04'),
|
||||||
(b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', b'\x00\x00\x00\x00\x00\x00\x00\x05')],
|
(b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', b'\x00\x00\x00\x00\x00\x00\x00\x05')
|
||||||
path_finder.find_path_for_payment(b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 100000))
|
], path)
|
||||||
|
start_node = b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb'
|
||||||
|
route = path_finder.create_route_from_path(path, start_node)
|
||||||
|
self.assertEqual(route[0].node_id, start_node)
|
||||||
|
self.assertEqual(route[0].short_channel_id, bfh('0000000000000003'))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user