1
0

create class for ShortChannelID and use it

This commit is contained in:
SomberNight
2019-09-06 18:09:05 +02:00
parent 251db638af
commit 509df9ddaf
8 changed files with 110 additions and 76 deletions

View File

@@ -37,7 +37,7 @@ from .sql_db import SqlDB, sql
from . import constants
from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
from .logging import Logger
from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id
from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id, ShortChannelID
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
if TYPE_CHECKING:
@@ -57,10 +57,10 @@ FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0
class ChannelInfo(NamedTuple):
short_channel_id: bytes
short_channel_id: ShortChannelID
node1_id: bytes
node2_id: bytes
capacity_sat: int
capacity_sat: Optional[int]
@staticmethod
def from_msg(payload):
@@ -72,10 +72,11 @@ class ChannelInfo(NamedTuple):
assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
capacity_sat = None
return ChannelInfo(
short_channel_id = channel_id,
short_channel_id = ShortChannelID.normalize(channel_id),
node1_id = node_id_1,
node2_id = node_id_2,
capacity_sat = capacity_sat)
capacity_sat = capacity_sat
)
class Policy(NamedTuple):
@@ -107,8 +108,8 @@ class Policy(NamedTuple):
return self.channel_flags & FLAG_DISABLE
@property
def short_channel_id(self):
return self.key[0:8]
def short_channel_id(self) -> ShortChannelID:
return ShortChannelID.normalize(self.key[0:8])
@property
def start_node(self):
@@ -290,7 +291,7 @@ class ChannelDB(SqlDB):
msg_payloads = [msg_payloads]
added = 0
for msg in msg_payloads:
short_channel_id = msg['short_channel_id']
short_channel_id = ShortChannelID(msg['short_channel_id'])
if short_channel_id in self._channels:
continue
if constants.net.rev_genesis_bytes() != msg['chain_hash']:
@@ -339,7 +340,7 @@ class ChannelDB(SqlDB):
known = []
now = int(time.time())
for payload in payloads:
short_channel_id = payload['short_channel_id']
short_channel_id = ShortChannelID(payload['short_channel_id'])
timestamp = int.from_bytes(payload['timestamp'], "big")
if max_age and now - timestamp > max_age:
expired.append(payload)
@@ -357,7 +358,7 @@ class ChannelDB(SqlDB):
for payload in known:
timestamp = int.from_bytes(payload['timestamp'], "big")
start_node = payload['start_node']
short_channel_id = payload['short_channel_id']
short_channel_id = ShortChannelID(payload['short_channel_id'])
key = (start_node, short_channel_id)
old_policy = self._policies.get(key)
if old_policy and timestamp <= old_policy.timestamp:
@@ -434,11 +435,11 @@ class ChannelDB(SqlDB):
def verify_channel_update(self, payload):
short_channel_id = payload['short_channel_id']
scid = format_short_channel_id(short_channel_id)
short_channel_id = ShortChannelID(short_channel_id)
if constants.net.rev_genesis_bytes() != payload['chain_hash']:
raise Exception('wrong chain hash')
if not verify_sig_for_channel_update(payload, payload['start_node']):
raise Exception(f'failed verifying channel update for {scid}')
raise Exception(f'failed verifying channel update for {short_channel_id}')
def add_node_announcement(self, msg_payloads):
if type(msg_payloads) is dict:
@@ -510,11 +511,11 @@ class ChannelDB(SqlDB):
def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
if not verify_sig_for_channel_update(msg_payload, start_node_id):
return # ignore
short_channel_id = msg_payload['short_channel_id']
short_channel_id = ShortChannelID(msg_payload['short_channel_id'])
msg_payload['start_node'] = start_node_id
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
def remove_channel(self, short_channel_id):
def remove_channel(self, short_channel_id: ShortChannelID):
channel_info = self._channels.pop(short_channel_id, None)
if channel_info:
self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id)
@@ -533,6 +534,7 @@ class ChannelDB(SqlDB):
self._addresses[node_id].add((str(host), int(port), int(timestamp or 0)))
c.execute("""SELECT * FROM channel_info""")
for x in c:
x = (ShortChannelID.normalize(x[0]), *x[1:])
ci = ChannelInfo(*x)
self._channels[ci.short_channel_id] = ci
c.execute("""SELECT * FROM node_info""")