persist channel db on disk. verify channel gossip sigs.
This commit is contained in:
@@ -30,8 +30,11 @@ import sys
|
||||
import binascii
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import json
|
||||
import threading
|
||||
from collections import namedtuple, defaultdict
|
||||
from typing import Sequence, Union, Tuple
|
||||
from typing import Sequence, Union, Tuple, Optional
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
@@ -39,9 +42,12 @@ from cryptography.hazmat.backends import default_backend
|
||||
from . import bitcoin
|
||||
from . import ecc
|
||||
from . import crypto
|
||||
from . import constants
|
||||
from .crypto import sha256
|
||||
from .util import PrintError, bh2u, profiler, xor_bytes
|
||||
from .util import PrintError, bh2u, profiler, xor_bytes, get_headers_dir, bfh
|
||||
from .lnutil import get_ecdh
|
||||
from .storage import JsonDB
|
||||
from .lnchanannverifier import LNChanAnnVerifier, verify_sig_for_channel_update
|
||||
|
||||
|
||||
class ChannelInfo(PrintError):
|
||||
@@ -54,23 +60,71 @@ class ChannelInfo(PrintError):
|
||||
assert type(self.node_id_2) is bytes
|
||||
assert list(sorted([self.node_id_1, self.node_id_2])) == [self.node_id_1, self.node_id_2]
|
||||
|
||||
self.features_len = channel_announcement_payload['len']
|
||||
self.features = channel_announcement_payload['features']
|
||||
self.bitcoin_key_1 = channel_announcement_payload['bitcoin_key_1']
|
||||
self.bitcoin_key_2 = channel_announcement_payload['bitcoin_key_2']
|
||||
|
||||
# this field does not get persisted
|
||||
self.msg_payload = channel_announcement_payload
|
||||
|
||||
self.capacity_sat = None
|
||||
self.policy_node1 = None
|
||||
self.policy_node2 = None
|
||||
|
||||
def to_json(self) -> dict:
|
||||
d = {}
|
||||
d['short_channel_id'] = bh2u(self.channel_id)
|
||||
d['node_id_1'] = bh2u(self.node_id_1)
|
||||
d['node_id_2'] = bh2u(self.node_id_2)
|
||||
d['len'] = bh2u(self.features_len)
|
||||
d['features'] = bh2u(self.features)
|
||||
d['bitcoin_key_1'] = bh2u(self.bitcoin_key_1)
|
||||
d['bitcoin_key_2'] = bh2u(self.bitcoin_key_2)
|
||||
d['policy_node1'] = self.policy_node1
|
||||
d['policy_node2'] = self.policy_node2
|
||||
d['capacity_sat'] = self.capacity_sat
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, d: dict):
|
||||
d2 = {}
|
||||
d2['short_channel_id'] = bfh(d['short_channel_id'])
|
||||
d2['node_id_1'] = bfh(d['node_id_1'])
|
||||
d2['node_id_2'] = bfh(d['node_id_2'])
|
||||
d2['len'] = bfh(d['len'])
|
||||
d2['features'] = bfh(d['features'])
|
||||
d2['bitcoin_key_1'] = bfh(d['bitcoin_key_1'])
|
||||
d2['bitcoin_key_2'] = bfh(d['bitcoin_key_2'])
|
||||
ci = ChannelInfo(d2)
|
||||
ci.capacity_sat = d['capacity_sat']
|
||||
ci.policy_node1 = ChannelInfoDirectedPolicy.from_json(d['policy_node1'])
|
||||
ci.policy_node2 = ChannelInfoDirectedPolicy.from_json(d['policy_node2'])
|
||||
return ci
|
||||
|
||||
def set_capacity(self, capacity):
|
||||
# TODO call this after looking up UTXO for funding txn on chain
|
||||
self.capacity_sat = capacity
|
||||
|
||||
def on_channel_update(self, msg_payload):
|
||||
assert self.channel_id == msg_payload['short_channel_id']
|
||||
flags = int.from_bytes(msg_payload['flags'], 'big')
|
||||
direction = flags & 1
|
||||
new_policy = ChannelInfoDirectedPolicy(msg_payload)
|
||||
if direction == 0:
|
||||
self.policy_node1 = ChannelInfoDirectedPolicy(msg_payload)
|
||||
old_policy = self.policy_node1
|
||||
node_id = self.node_id_1
|
||||
else:
|
||||
self.policy_node2 = ChannelInfoDirectedPolicy(msg_payload)
|
||||
#self.print_error('channel update', binascii.hexlify(self.channel_id).decode("ascii"), flags)
|
||||
old_policy = self.policy_node2
|
||||
node_id = self.node_id_2
|
||||
if old_policy and old_policy.timestamp >= new_policy.timestamp:
|
||||
return # ignore
|
||||
if not verify_sig_for_channel_update(msg_payload, node_id):
|
||||
return # ignore
|
||||
# save new policy
|
||||
if direction == 0:
|
||||
self.policy_node1 = new_policy
|
||||
else:
|
||||
self.policy_node2 = new_policy
|
||||
|
||||
def get_policy_for_node(self, node_id):
|
||||
if node_id == self.node_id_1:
|
||||
@@ -84,51 +138,121 @@ class ChannelInfo(PrintError):
|
||||
class ChannelInfoDirectedPolicy:
|
||||
|
||||
def __init__(self, channel_update_payload):
|
||||
self.cltv_expiry_delta = channel_update_payload['cltv_expiry_delta']
|
||||
self.htlc_minimum_msat = channel_update_payload['htlc_minimum_msat']
|
||||
self.fee_base_msat = channel_update_payload['fee_base_msat']
|
||||
self.fee_proportional_millionths = channel_update_payload['fee_proportional_millionths']
|
||||
self.cltv_expiry_delta = int.from_bytes(self.cltv_expiry_delta, "big")
|
||||
self.htlc_minimum_msat = int.from_bytes(self.htlc_minimum_msat, "big")
|
||||
self.fee_base_msat = int.from_bytes(self.fee_base_msat, "big")
|
||||
self.fee_proportional_millionths = int.from_bytes(self.fee_proportional_millionths, "big")
|
||||
cltv_expiry_delta = channel_update_payload['cltv_expiry_delta']
|
||||
htlc_minimum_msat = channel_update_payload['htlc_minimum_msat']
|
||||
fee_base_msat = channel_update_payload['fee_base_msat']
|
||||
fee_proportional_millionths = channel_update_payload['fee_proportional_millionths']
|
||||
flags = channel_update_payload['flags']
|
||||
timestamp = channel_update_payload['timestamp']
|
||||
|
||||
self.cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big")
|
||||
self.htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big")
|
||||
self.fee_base_msat = int.from_bytes(fee_base_msat, "big")
|
||||
self.fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big")
|
||||
self.flags = int.from_bytes(flags, "big")
|
||||
self.timestamp = int.from_bytes(timestamp, "big")
|
||||
|
||||
def to_json(self) -> dict:
|
||||
d = {}
|
||||
d['cltv_expiry_delta'] = self.cltv_expiry_delta
|
||||
d['htlc_minimum_msat'] = self.htlc_minimum_msat
|
||||
d['fee_base_msat'] = self.fee_base_msat
|
||||
d['fee_proportional_millionths'] = self.fee_proportional_millionths
|
||||
d['flags'] = self.flags
|
||||
d['timestamp'] = self.timestamp
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, d: dict):
|
||||
if d is None: return None
|
||||
d2 = {}
|
||||
d2['cltv_expiry_delta'] = d['cltv_expiry_delta'].to_bytes(2, "big")
|
||||
d2['htlc_minimum_msat'] = d['htlc_minimum_msat'].to_bytes(8, "big")
|
||||
d2['fee_base_msat'] = d['fee_base_msat'].to_bytes(4, "big")
|
||||
d2['fee_proportional_millionths'] = d['fee_proportional_millionths'].to_bytes(4, "big")
|
||||
d2['flags'] = d['flags'].to_bytes(2, "big")
|
||||
d2['timestamp'] = d['timestamp'].to_bytes(4, "big")
|
||||
return ChannelInfoDirectedPolicy(d2)
|
||||
|
||||
|
||||
class ChannelDB(PrintError):
|
||||
class ChannelDB(JsonDB):
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, network):
|
||||
self.network = network
|
||||
|
||||
path = os.path.join(get_headers_dir(network.config), 'channel_db')
|
||||
JsonDB.__init__(self, path)
|
||||
|
||||
self.lock = threading.Lock()
|
||||
self._id_to_channel_info = {}
|
||||
self._channels_for_node = defaultdict(set) # node -> set(short_channel_id)
|
||||
|
||||
self.ca_verifier = LNChanAnnVerifier(network, self)
|
||||
self.network.add_jobs([self.ca_verifier])
|
||||
|
||||
self.load_data()
|
||||
|
||||
def load_data(self):
|
||||
if os.path.exists(self.path):
|
||||
with open(self.path, "r", encoding='utf-8') as f:
|
||||
raw = f.read()
|
||||
self.data = json.loads(raw)
|
||||
channel_infos = self.get('channel_infos', {})
|
||||
for short_channel_id, channel_info_d in channel_infos.items():
|
||||
channel_info = ChannelInfo.from_json(channel_info_d)
|
||||
short_channel_id = bfh(short_channel_id)
|
||||
self.add_verified_channel_info(short_channel_id, channel_info)
|
||||
|
||||
def save_data(self):
|
||||
with self.lock:
|
||||
channel_infos = {}
|
||||
for short_channel_id, channel_info in self._id_to_channel_info.items():
|
||||
channel_infos[bh2u(short_channel_id)] = channel_info
|
||||
self.put('channel_infos', channel_infos)
|
||||
self.write()
|
||||
|
||||
def __len__(self):
|
||||
return len(self._id_to_channel_info)
|
||||
|
||||
def get_channel_info(self, channel_id):
|
||||
def get_channel_info(self, channel_id) -> Optional[ChannelInfo]:
|
||||
return self._id_to_channel_info.get(channel_id, None)
|
||||
|
||||
def get_channels_for_node(self, node_id):
|
||||
"""Returns the set of channels that have node_id as one of the endpoints."""
|
||||
return self._channels_for_node[node_id]
|
||||
|
||||
def add_verified_channel_info(self, short_channel_id: bytes, channel_info: ChannelInfo):
|
||||
with self.lock:
|
||||
self._id_to_channel_info[short_channel_id] = channel_info
|
||||
self._channels_for_node[channel_info.node_id_1].add(short_channel_id)
|
||||
self._channels_for_node[channel_info.node_id_2].add(short_channel_id)
|
||||
|
||||
def on_channel_announcement(self, msg_payload):
|
||||
short_channel_id = msg_payload['short_channel_id']
|
||||
#self.print_error('channel announcement', binascii.hexlify(short_channel_id).decode("ascii"))
|
||||
channel_info = ChannelInfo(msg_payload)
|
||||
if short_channel_id in self._id_to_channel_info:
|
||||
self.print_error("IGNORING CHANNEL ANNOUNCEMENT, WE ALREADY KNOW THIS CHANNEL")
|
||||
return
|
||||
self._id_to_channel_info[short_channel_id] = channel_info
|
||||
self._channels_for_node[channel_info.node_id_1].add(short_channel_id)
|
||||
self._channels_for_node[channel_info.node_id_2].add(short_channel_id)
|
||||
if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']:
|
||||
return
|
||||
channel_info = ChannelInfo(msg_payload)
|
||||
self.ca_verifier.add_new_channel_info(channel_info)
|
||||
|
||||
def on_channel_update(self, msg_payload):
|
||||
short_channel_id = msg_payload['short_channel_id']
|
||||
try:
|
||||
channel_info = self._id_to_channel_info[short_channel_id]
|
||||
except KeyError:
|
||||
if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']:
|
||||
return
|
||||
# try finding channel in verified db
|
||||
channel_info = self._id_to_channel_info.get(short_channel_id, None)
|
||||
if channel_info is None:
|
||||
# try finding channel in pending db
|
||||
channel_info = self.ca_verifier.get_pending_channel_info(short_channel_id)
|
||||
if channel_info is None:
|
||||
# try finding channel in verified db, again
|
||||
# (maybe this is redundant but this should prevent a race..)
|
||||
channel_info = self._id_to_channel_info.get(short_channel_id, None)
|
||||
if channel_info is None:
|
||||
self.print_error("could not find", short_channel_id)
|
||||
else:
|
||||
channel_info.on_channel_update(msg_payload)
|
||||
return
|
||||
channel_info.on_channel_update(msg_payload)
|
||||
|
||||
def remove_channel(self, short_channel_id):
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user