1
0

persist channel db on disk. verify channel gossip sigs.

This commit is contained in:
SomberNight
2018-07-23 20:49:44 +02:00
committed by ThomasV
parent c1d1826014
commit a5b44d25b0
8 changed files with 396 additions and 55 deletions

View File

@@ -30,8 +30,11 @@ import sys
import binascii
import hashlib
import hmac
import os
import json
import threading
from collections import namedtuple, defaultdict
from typing import Sequence, Union, Tuple
from typing import Sequence, Union, Tuple, Optional
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
from cryptography.hazmat.backends import default_backend
@@ -39,9 +42,12 @@ from cryptography.hazmat.backends import default_backend
from . import bitcoin
from . import ecc
from . import crypto
from . import constants
from .crypto import sha256
from .util import PrintError, bh2u, profiler, xor_bytes
from .util import PrintError, bh2u, profiler, xor_bytes, get_headers_dir, bfh
from .lnutil import get_ecdh
from .storage import JsonDB
from .lnchanannverifier import LNChanAnnVerifier, verify_sig_for_channel_update
class ChannelInfo(PrintError):
@@ -54,23 +60,71 @@ class ChannelInfo(PrintError):
assert type(self.node_id_2) is bytes
assert list(sorted([self.node_id_1, self.node_id_2])) == [self.node_id_1, self.node_id_2]
self.features_len = channel_announcement_payload['len']
self.features = channel_announcement_payload['features']
self.bitcoin_key_1 = channel_announcement_payload['bitcoin_key_1']
self.bitcoin_key_2 = channel_announcement_payload['bitcoin_key_2']
# this field does not get persisted
self.msg_payload = channel_announcement_payload
self.capacity_sat = None
self.policy_node1 = None
self.policy_node2 = None
def to_json(self) -> dict:
d = {}
d['short_channel_id'] = bh2u(self.channel_id)
d['node_id_1'] = bh2u(self.node_id_1)
d['node_id_2'] = bh2u(self.node_id_2)
d['len'] = bh2u(self.features_len)
d['features'] = bh2u(self.features)
d['bitcoin_key_1'] = bh2u(self.bitcoin_key_1)
d['bitcoin_key_2'] = bh2u(self.bitcoin_key_2)
d['policy_node1'] = self.policy_node1
d['policy_node2'] = self.policy_node2
d['capacity_sat'] = self.capacity_sat
return d
@classmethod
def from_json(cls, d: dict):
d2 = {}
d2['short_channel_id'] = bfh(d['short_channel_id'])
d2['node_id_1'] = bfh(d['node_id_1'])
d2['node_id_2'] = bfh(d['node_id_2'])
d2['len'] = bfh(d['len'])
d2['features'] = bfh(d['features'])
d2['bitcoin_key_1'] = bfh(d['bitcoin_key_1'])
d2['bitcoin_key_2'] = bfh(d['bitcoin_key_2'])
ci = ChannelInfo(d2)
ci.capacity_sat = d['capacity_sat']
ci.policy_node1 = ChannelInfoDirectedPolicy.from_json(d['policy_node1'])
ci.policy_node2 = ChannelInfoDirectedPolicy.from_json(d['policy_node2'])
return ci
def set_capacity(self, capacity):
# TODO call this after looking up UTXO for funding txn on chain
self.capacity_sat = capacity
def on_channel_update(self, msg_payload):
assert self.channel_id == msg_payload['short_channel_id']
flags = int.from_bytes(msg_payload['flags'], 'big')
direction = flags & 1
new_policy = ChannelInfoDirectedPolicy(msg_payload)
if direction == 0:
self.policy_node1 = ChannelInfoDirectedPolicy(msg_payload)
old_policy = self.policy_node1
node_id = self.node_id_1
else:
self.policy_node2 = ChannelInfoDirectedPolicy(msg_payload)
#self.print_error('channel update', binascii.hexlify(self.channel_id).decode("ascii"), flags)
old_policy = self.policy_node2
node_id = self.node_id_2
if old_policy and old_policy.timestamp >= new_policy.timestamp:
return # ignore
if not verify_sig_for_channel_update(msg_payload, node_id):
return # ignore
# save new policy
if direction == 0:
self.policy_node1 = new_policy
else:
self.policy_node2 = new_policy
def get_policy_for_node(self, node_id):
if node_id == self.node_id_1:
@@ -84,51 +138,121 @@ class ChannelInfo(PrintError):
class ChannelInfoDirectedPolicy:
def __init__(self, channel_update_payload):
self.cltv_expiry_delta = channel_update_payload['cltv_expiry_delta']
self.htlc_minimum_msat = channel_update_payload['htlc_minimum_msat']
self.fee_base_msat = channel_update_payload['fee_base_msat']
self.fee_proportional_millionths = channel_update_payload['fee_proportional_millionths']
self.cltv_expiry_delta = int.from_bytes(self.cltv_expiry_delta, "big")
self.htlc_minimum_msat = int.from_bytes(self.htlc_minimum_msat, "big")
self.fee_base_msat = int.from_bytes(self.fee_base_msat, "big")
self.fee_proportional_millionths = int.from_bytes(self.fee_proportional_millionths, "big")
cltv_expiry_delta = channel_update_payload['cltv_expiry_delta']
htlc_minimum_msat = channel_update_payload['htlc_minimum_msat']
fee_base_msat = channel_update_payload['fee_base_msat']
fee_proportional_millionths = channel_update_payload['fee_proportional_millionths']
flags = channel_update_payload['flags']
timestamp = channel_update_payload['timestamp']
self.cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big")
self.htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big")
self.fee_base_msat = int.from_bytes(fee_base_msat, "big")
self.fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big")
self.flags = int.from_bytes(flags, "big")
self.timestamp = int.from_bytes(timestamp, "big")
def to_json(self) -> dict:
d = {}
d['cltv_expiry_delta'] = self.cltv_expiry_delta
d['htlc_minimum_msat'] = self.htlc_minimum_msat
d['fee_base_msat'] = self.fee_base_msat
d['fee_proportional_millionths'] = self.fee_proportional_millionths
d['flags'] = self.flags
d['timestamp'] = self.timestamp
return d
@classmethod
def from_json(cls, d: dict):
if d is None: return None
d2 = {}
d2['cltv_expiry_delta'] = d['cltv_expiry_delta'].to_bytes(2, "big")
d2['htlc_minimum_msat'] = d['htlc_minimum_msat'].to_bytes(8, "big")
d2['fee_base_msat'] = d['fee_base_msat'].to_bytes(4, "big")
d2['fee_proportional_millionths'] = d['fee_proportional_millionths'].to_bytes(4, "big")
d2['flags'] = d['flags'].to_bytes(2, "big")
d2['timestamp'] = d['timestamp'].to_bytes(4, "big")
return ChannelInfoDirectedPolicy(d2)
class ChannelDB(PrintError):
class ChannelDB(JsonDB):
def __init__(self):
def __init__(self, network):
self.network = network
path = os.path.join(get_headers_dir(network.config), 'channel_db')
JsonDB.__init__(self, path)
self.lock = threading.Lock()
self._id_to_channel_info = {}
self._channels_for_node = defaultdict(set) # node -> set(short_channel_id)
self.ca_verifier = LNChanAnnVerifier(network, self)
self.network.add_jobs([self.ca_verifier])
self.load_data()
def load_data(self):
if os.path.exists(self.path):
with open(self.path, "r", encoding='utf-8') as f:
raw = f.read()
self.data = json.loads(raw)
channel_infos = self.get('channel_infos', {})
for short_channel_id, channel_info_d in channel_infos.items():
channel_info = ChannelInfo.from_json(channel_info_d)
short_channel_id = bfh(short_channel_id)
self.add_verified_channel_info(short_channel_id, channel_info)
def save_data(self):
with self.lock:
channel_infos = {}
for short_channel_id, channel_info in self._id_to_channel_info.items():
channel_infos[bh2u(short_channel_id)] = channel_info
self.put('channel_infos', channel_infos)
self.write()
def __len__(self):
return len(self._id_to_channel_info)
def get_channel_info(self, channel_id):
def get_channel_info(self, channel_id) -> Optional[ChannelInfo]:
return self._id_to_channel_info.get(channel_id, None)
def get_channels_for_node(self, node_id):
"""Returns the set of channels that have node_id as one of the endpoints."""
return self._channels_for_node[node_id]
def add_verified_channel_info(self, short_channel_id: bytes, channel_info: ChannelInfo):
with self.lock:
self._id_to_channel_info[short_channel_id] = channel_info
self._channels_for_node[channel_info.node_id_1].add(short_channel_id)
self._channels_for_node[channel_info.node_id_2].add(short_channel_id)
def on_channel_announcement(self, msg_payload):
short_channel_id = msg_payload['short_channel_id']
#self.print_error('channel announcement', binascii.hexlify(short_channel_id).decode("ascii"))
channel_info = ChannelInfo(msg_payload)
if short_channel_id in self._id_to_channel_info:
self.print_error("IGNORING CHANNEL ANNOUNCEMENT, WE ALREADY KNOW THIS CHANNEL")
return
self._id_to_channel_info[short_channel_id] = channel_info
self._channels_for_node[channel_info.node_id_1].add(short_channel_id)
self._channels_for_node[channel_info.node_id_2].add(short_channel_id)
if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']:
return
channel_info = ChannelInfo(msg_payload)
self.ca_verifier.add_new_channel_info(channel_info)
def on_channel_update(self, msg_payload):
short_channel_id = msg_payload['short_channel_id']
try:
channel_info = self._id_to_channel_info[short_channel_id]
except KeyError:
if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']:
return
# try finding channel in verified db
channel_info = self._id_to_channel_info.get(short_channel_id, None)
if channel_info is None:
# try finding channel in pending db
channel_info = self.ca_verifier.get_pending_channel_info(short_channel_id)
if channel_info is None:
# try finding channel in verified db, again
# (maybe this is redundant but this should prevent a race..)
channel_info = self._id_to_channel_info.get(short_channel_id, None)
if channel_info is None:
self.print_error("could not find", short_channel_id)
else:
channel_info.on_channel_update(msg_payload)
return
channel_info.on_channel_update(msg_payload)
def remove_channel(self, short_channel_id):
try: