create class for ShortChannelID and use it
This commit is contained in:
@@ -37,7 +37,7 @@ from .sql_db import SqlDB, sql
|
||||
from . import constants
|
||||
from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
|
||||
from .logging import Logger
|
||||
from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id
|
||||
from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id, ShortChannelID
|
||||
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -57,10 +57,10 @@ FLAG_DISABLE = 1 << 1
|
||||
FLAG_DIRECTION = 1 << 0
|
||||
|
||||
class ChannelInfo(NamedTuple):
|
||||
short_channel_id: bytes
|
||||
short_channel_id: ShortChannelID
|
||||
node1_id: bytes
|
||||
node2_id: bytes
|
||||
capacity_sat: int
|
||||
capacity_sat: Optional[int]
|
||||
|
||||
@staticmethod
|
||||
def from_msg(payload):
|
||||
@@ -72,10 +72,11 @@ class ChannelInfo(NamedTuple):
|
||||
assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
|
||||
capacity_sat = None
|
||||
return ChannelInfo(
|
||||
short_channel_id = channel_id,
|
||||
short_channel_id = ShortChannelID.normalize(channel_id),
|
||||
node1_id = node_id_1,
|
||||
node2_id = node_id_2,
|
||||
capacity_sat = capacity_sat)
|
||||
capacity_sat = capacity_sat
|
||||
)
|
||||
|
||||
|
||||
class Policy(NamedTuple):
|
||||
@@ -107,8 +108,8 @@ class Policy(NamedTuple):
|
||||
return self.channel_flags & FLAG_DISABLE
|
||||
|
||||
@property
|
||||
def short_channel_id(self):
|
||||
return self.key[0:8]
|
||||
def short_channel_id(self) -> ShortChannelID:
|
||||
return ShortChannelID.normalize(self.key[0:8])
|
||||
|
||||
@property
|
||||
def start_node(self):
|
||||
@@ -290,7 +291,7 @@ class ChannelDB(SqlDB):
|
||||
msg_payloads = [msg_payloads]
|
||||
added = 0
|
||||
for msg in msg_payloads:
|
||||
short_channel_id = msg['short_channel_id']
|
||||
short_channel_id = ShortChannelID(msg['short_channel_id'])
|
||||
if short_channel_id in self._channels:
|
||||
continue
|
||||
if constants.net.rev_genesis_bytes() != msg['chain_hash']:
|
||||
@@ -339,7 +340,7 @@ class ChannelDB(SqlDB):
|
||||
known = []
|
||||
now = int(time.time())
|
||||
for payload in payloads:
|
||||
short_channel_id = payload['short_channel_id']
|
||||
short_channel_id = ShortChannelID(payload['short_channel_id'])
|
||||
timestamp = int.from_bytes(payload['timestamp'], "big")
|
||||
if max_age and now - timestamp > max_age:
|
||||
expired.append(payload)
|
||||
@@ -357,7 +358,7 @@ class ChannelDB(SqlDB):
|
||||
for payload in known:
|
||||
timestamp = int.from_bytes(payload['timestamp'], "big")
|
||||
start_node = payload['start_node']
|
||||
short_channel_id = payload['short_channel_id']
|
||||
short_channel_id = ShortChannelID(payload['short_channel_id'])
|
||||
key = (start_node, short_channel_id)
|
||||
old_policy = self._policies.get(key)
|
||||
if old_policy and timestamp <= old_policy.timestamp:
|
||||
@@ -434,11 +435,11 @@ class ChannelDB(SqlDB):
|
||||
|
||||
def verify_channel_update(self, payload):
|
||||
short_channel_id = payload['short_channel_id']
|
||||
scid = format_short_channel_id(short_channel_id)
|
||||
short_channel_id = ShortChannelID(short_channel_id)
|
||||
if constants.net.rev_genesis_bytes() != payload['chain_hash']:
|
||||
raise Exception('wrong chain hash')
|
||||
if not verify_sig_for_channel_update(payload, payload['start_node']):
|
||||
raise Exception(f'failed verifying channel update for {scid}')
|
||||
raise Exception(f'failed verifying channel update for {short_channel_id}')
|
||||
|
||||
def add_node_announcement(self, msg_payloads):
|
||||
if type(msg_payloads) is dict:
|
||||
@@ -510,11 +511,11 @@ class ChannelDB(SqlDB):
|
||||
def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
|
||||
if not verify_sig_for_channel_update(msg_payload, start_node_id):
|
||||
return # ignore
|
||||
short_channel_id = msg_payload['short_channel_id']
|
||||
short_channel_id = ShortChannelID(msg_payload['short_channel_id'])
|
||||
msg_payload['start_node'] = start_node_id
|
||||
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
|
||||
|
||||
def remove_channel(self, short_channel_id):
|
||||
def remove_channel(self, short_channel_id: ShortChannelID):
|
||||
channel_info = self._channels.pop(short_channel_id, None)
|
||||
if channel_info:
|
||||
self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id)
|
||||
@@ -533,6 +534,7 @@ class ChannelDB(SqlDB):
|
||||
self._addresses[node_id].add((str(host), int(port), int(timestamp or 0)))
|
||||
c.execute("""SELECT * FROM channel_info""")
|
||||
for x in c:
|
||||
x = (ShortChannelID.normalize(x[0]), *x[1:])
|
||||
ci = ChannelInfo(*x)
|
||||
self._channels[ci.short_channel_id] = ci
|
||||
c.execute("""SELECT * FROM node_info""")
|
||||
|
||||
Reference in New Issue
Block a user