fix sql conflicts in lnrouter
This commit is contained in:
@@ -23,7 +23,7 @@
|
||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import datetime
|
||||
import time
|
||||
import random
|
||||
import queue
|
||||
import os
|
||||
@@ -35,7 +35,7 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
|
||||
import binascii
|
||||
import base64
|
||||
|
||||
from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
|
||||
from sqlalchemy import Column, ForeignKey, Integer, String, Boolean
|
||||
from sqlalchemy.orm.query import Query
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.sql import not_, or_
|
||||
@@ -81,14 +81,14 @@ class ChannelInfo(Base):
|
||||
trusted = Column(Boolean, nullable=False)
|
||||
|
||||
@staticmethod
|
||||
def from_msg(channel_announcement_payload):
|
||||
features = int.from_bytes(channel_announcement_payload['features'], 'big')
|
||||
def from_msg(payload):
|
||||
features = int.from_bytes(payload['features'], 'big')
|
||||
validate_features(features)
|
||||
channel_id = channel_announcement_payload['short_channel_id'].hex()
|
||||
node_id_1 = channel_announcement_payload['node_id_1'].hex()
|
||||
node_id_2 = channel_announcement_payload['node_id_2'].hex()
|
||||
channel_id = payload['short_channel_id'].hex()
|
||||
node_id_1 = payload['node_id_1'].hex()
|
||||
node_id_2 = payload['node_id_2'].hex()
|
||||
assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
|
||||
msg_payload_hex = encode_msg('channel_announcement', **channel_announcement_payload).hex()
|
||||
msg_payload_hex = encode_msg('channel_announcement', **payload).hex()
|
||||
capacity_sat = None
|
||||
return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1,
|
||||
node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex,
|
||||
@@ -109,17 +109,17 @@ class Policy(Base):
|
||||
fee_base_msat = Column(Integer, nullable=False)
|
||||
fee_proportional_millionths = Column(Integer, nullable=False)
|
||||
channel_flags = Column(Integer, nullable=False)
|
||||
timestamp = Column(DateTime, nullable=False)
|
||||
timestamp = Column(Integer, nullable=False)
|
||||
|
||||
@staticmethod
|
||||
def from_msg(channel_update_payload, start_node, short_channel_id):
|
||||
cltv_expiry_delta = channel_update_payload['cltv_expiry_delta']
|
||||
htlc_minimum_msat = channel_update_payload['htlc_minimum_msat']
|
||||
fee_base_msat = channel_update_payload['fee_base_msat']
|
||||
fee_proportional_millionths = channel_update_payload['fee_proportional_millionths']
|
||||
channel_flags = channel_update_payload['channel_flags']
|
||||
timestamp = channel_update_payload['timestamp']
|
||||
htlc_maximum_msat = channel_update_payload.get('htlc_maximum_msat') # optional
|
||||
def from_msg(payload, start_node, short_channel_id):
|
||||
cltv_expiry_delta = payload['cltv_expiry_delta']
|
||||
htlc_minimum_msat = payload['htlc_minimum_msat']
|
||||
fee_base_msat = payload['fee_base_msat']
|
||||
fee_proportional_millionths = payload['fee_proportional_millionths']
|
||||
channel_flags = payload['channel_flags']
|
||||
timestamp = payload['timestamp']
|
||||
htlc_maximum_msat = payload.get('htlc_maximum_msat') # optional
|
||||
|
||||
cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big")
|
||||
htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big")
|
||||
@@ -127,7 +127,7 @@ class Policy(Base):
|
||||
fee_base_msat = int.from_bytes(fee_base_msat, "big")
|
||||
fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big")
|
||||
channel_flags = int.from_bytes(channel_flags, "big")
|
||||
timestamp = datetime.datetime.fromtimestamp(int.from_bytes(timestamp, "big"))
|
||||
timestamp = int.from_bytes(timestamp, "big")
|
||||
|
||||
return Policy(start_node=start_node,
|
||||
short_channel_id=short_channel_id,
|
||||
@@ -150,17 +150,16 @@ class NodeInfo(Base):
|
||||
alias = Column(String(64), nullable=False)
|
||||
|
||||
@staticmethod
|
||||
def from_msg(node_announcement_payload, addresses_already_parsed=False):
|
||||
node_id = node_announcement_payload['node_id'].hex()
|
||||
features = int.from_bytes(node_announcement_payload['features'], "big")
|
||||
def from_msg(payload):
|
||||
node_id = payload['node_id'].hex()
|
||||
features = int.from_bytes(payload['features'], "big")
|
||||
validate_features(features)
|
||||
if not addresses_already_parsed:
|
||||
addresses = NodeInfo.parse_addresses_field(node_announcement_payload['addresses'])
|
||||
else:
|
||||
addresses = node_announcement_payload['addresses']
|
||||
alias = node_announcement_payload['alias'].rstrip(b'\x00').hex()
|
||||
timestamp = datetime.datetime.fromtimestamp(int.from_bytes(node_announcement_payload['timestamp'], "big"))
|
||||
return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [Address(host=host, port=port, node_id=node_id, last_connected_date=datetime.datetime.now()) for host, port in addresses]
|
||||
addresses = NodeInfo.parse_addresses_field(payload['addresses'])
|
||||
alias = payload['alias'].rstrip(b'\x00').hex()
|
||||
timestamp = int.from_bytes(payload['timestamp'], "big")
|
||||
now = int(time.time())
|
||||
return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [
|
||||
Address(host=host, port=port, node_id=node_id, last_connected_date=now) for host, port in addresses]
|
||||
|
||||
@staticmethod
|
||||
def parse_addresses_field(addresses_field):
|
||||
@@ -207,7 +206,7 @@ class Address(Base):
|
||||
node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
|
||||
host = Column(String(256), primary_key=True)
|
||||
port = Column(Integer, primary_key=True)
|
||||
last_connected_date = Column(DateTime(), nullable=False)
|
||||
last_connected_date = Column(Integer(), nullable=False)
|
||||
|
||||
|
||||
|
||||
@@ -235,12 +234,14 @@ class ChannelDB(SqlDB):
|
||||
|
||||
@sql
|
||||
def add_recent_peer(self, peer: LNPeerAddr):
|
||||
addr = self.DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none()
|
||||
if addr is None:
|
||||
addr = Address(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now())
|
||||
now = int(time.time())
|
||||
node_id = peer.pubkey.hex()
|
||||
addr = self.DBSession.query(Address).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
|
||||
if addr:
|
||||
addr.last_connected_date = now
|
||||
else:
|
||||
addr.last_connected_date = datetime.datetime.now()
|
||||
self.DBSession.add(addr)
|
||||
addr = Address(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
|
||||
self.DBSession.add(addr)
|
||||
self.DBSession.commit()
|
||||
|
||||
@sql
|
||||
@@ -317,25 +318,31 @@ class ChannelDB(SqlDB):
|
||||
self.DBSession.commit()
|
||||
|
||||
@sql
|
||||
@profiler
|
||||
#@profiler
|
||||
def on_channel_announcement(self, msg_payloads, trusted=False):
|
||||
if type(msg_payloads) is dict:
|
||||
msg_payloads = [msg_payloads]
|
||||
new_channels = {}
|
||||
for msg in msg_payloads:
|
||||
short_channel_id = msg['short_channel_id']
|
||||
if self.DBSession.query(ChannelInfo).filter_by(short_channel_id = bh2u(short_channel_id)).count():
|
||||
short_channel_id = bh2u(msg['short_channel_id'])
|
||||
if self.DBSession.query(ChannelInfo).filter_by(short_channel_id=short_channel_id).count():
|
||||
continue
|
||||
if constants.net.rev_genesis_bytes() != msg['chain_hash']:
|
||||
#self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash'])))
|
||||
self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash'])))
|
||||
continue
|
||||
try:
|
||||
channel_info = ChannelInfo.from_msg(msg)
|
||||
except UnknownEvenFeatureBits:
|
||||
self.print_error("unknown feature bits")
|
||||
continue
|
||||
channel_info.trusted = trusted
|
||||
new_channels[short_channel_id] = channel_info
|
||||
if not trusted:
|
||||
self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
|
||||
for channel_info in new_channels.values():
|
||||
self.DBSession.add(channel_info)
|
||||
if not trusted: self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
|
||||
self.DBSession.commit()
|
||||
self.print_error('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads)))
|
||||
self._update_counts()
|
||||
self.network.trigger_callback('ln_status')
|
||||
|
||||
@@ -379,21 +386,13 @@ class ChannelDB(SqlDB):
|
||||
self.DBSession.commit()
|
||||
|
||||
@sql
|
||||
@profiler
|
||||
#@profiler
|
||||
def on_node_announcement(self, msg_payloads):
|
||||
if type(msg_payloads) is dict:
|
||||
msg_payloads = [msg_payloads]
|
||||
addresses = self.DBSession.query(Address).all()
|
||||
have_addr = {}
|
||||
for addr in addresses:
|
||||
have_addr[(addr.node_id, addr.host, addr.port)] = addr
|
||||
|
||||
nodes = self.DBSession.query(NodeInfo).all()
|
||||
timestamps = {}
|
||||
for node in nodes:
|
||||
no_millisecs = node.timestamp[:len("0000-00-00 00:00:00")]
|
||||
timestamps[bfh(node.node_id)] = datetime.datetime.strptime(no_millisecs, "%Y-%m-%d %H:%M:%S")
|
||||
old_addr = None
|
||||
new_nodes = {}
|
||||
new_addresses = {}
|
||||
for msg_payload in msg_payloads:
|
||||
pubkey = msg_payload['node_id']
|
||||
signature = msg_payload['signature']
|
||||
@@ -401,30 +400,33 @@ class ChannelDB(SqlDB):
|
||||
if not ecc.verify_signature(pubkey, signature, h):
|
||||
continue
|
||||
try:
|
||||
new_node_info, addresses = NodeInfo.from_msg(msg_payload)
|
||||
node_info, node_addresses = NodeInfo.from_msg(msg_payload)
|
||||
except UnknownEvenFeatureBits:
|
||||
continue
|
||||
if timestamps.get(pubkey) and timestamps[pubkey] >= new_node_info.timestamp:
|
||||
continue # ignore
|
||||
self.DBSession.add(new_node_info)
|
||||
for new_addr in addresses:
|
||||
key = (new_addr.node_id, new_addr.host, new_addr.port)
|
||||
old_addr = have_addr.get(key)
|
||||
if old_addr:
|
||||
# since old_addr is embedded in have_addr,
|
||||
# it will still live when commmit is called
|
||||
old_addr.last_connected_date = new_addr.last_connected_date
|
||||
del new_addr
|
||||
else:
|
||||
self.DBSession.add(new_addr)
|
||||
have_addr[key] = new_addr
|
||||
node_id = node_info.node_id
|
||||
node = self.DBSession.query(NodeInfo).filter_by(node_id=node_id).one_or_none()
|
||||
if node and node.timestamp >= node_info.timestamp:
|
||||
continue
|
||||
node = new_nodes.get(node_id)
|
||||
if node and node.timestamp >= node_info.timestamp:
|
||||
continue
|
||||
new_nodes[node_id] = node_info
|
||||
for addr in node_addresses:
|
||||
new_addresses[(addr.node_id,addr.host,addr.port)] = addr
|
||||
|
||||
self.print_error("on_node_announcements: %d/%d"%(len(new_nodes), len(msg_payloads)))
|
||||
for node_info in new_nodes.values():
|
||||
self.DBSession.add(node_info)
|
||||
for new_addr in new_addresses.values():
|
||||
old_addr = self.DBSession.query(Address).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
|
||||
if old_addr:
|
||||
old_addr.last_connected_date = new_addr.last_connected_date
|
||||
else:
|
||||
self.DBSession.add(new_addr)
|
||||
# TODO if this message is for a new node, and if we have no associated
|
||||
# channels for this node, we should ignore the message and return here,
|
||||
# to mitigate DOS. but race condition: the channels we have for this
|
||||
# node, might be under verification in self.ca_verifier, what then?
|
||||
del nodes, addresses
|
||||
if old_addr:
|
||||
del old_addr
|
||||
self.DBSession.commit()
|
||||
self._update_counts()
|
||||
self.network.trigger_callback('ln_status')
|
||||
|
||||
Reference in New Issue
Block a user