ln gossip: don't put own channels into db; always pass them to fn calls
Previously we would put fake chan announcement and fake outgoing chan upd for own channels into db (to make path finding work). See Peer.add_own_channel(). Now, instead of above, we pass a "my_channels" param to the relevant ChannelDB methods.
This commit is contained in:
@@ -39,9 +39,11 @@ from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enab
|
||||
from .logging import Logger
|
||||
from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id, ShortChannelID
|
||||
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
|
||||
from .lnmsg import decode_msg
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .network import Network
|
||||
from .lnchannel import Channel
|
||||
|
||||
|
||||
class UnknownEvenFeatureBits(Exception): pass
|
||||
@@ -63,7 +65,7 @@ class ChannelInfo(NamedTuple):
|
||||
capacity_sat: Optional[int]
|
||||
|
||||
@staticmethod
|
||||
def from_msg(payload):
|
||||
def from_msg(payload: dict) -> 'ChannelInfo':
|
||||
features = int.from_bytes(payload['features'], 'big')
|
||||
validate_features(features)
|
||||
channel_id = payload['short_channel_id']
|
||||
@@ -78,6 +80,11 @@ class ChannelInfo(NamedTuple):
|
||||
capacity_sat = capacity_sat
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_raw_msg(raw: bytes) -> 'ChannelInfo':
|
||||
payload_dict = decode_msg(raw)[1]
|
||||
return ChannelInfo.from_msg(payload_dict)
|
||||
|
||||
|
||||
class Policy(NamedTuple):
|
||||
key: bytes
|
||||
@@ -91,7 +98,7 @@ class Policy(NamedTuple):
|
||||
timestamp: int
|
||||
|
||||
@staticmethod
|
||||
def from_msg(payload):
|
||||
def from_msg(payload: dict) -> 'Policy':
|
||||
return Policy(
|
||||
key = payload['short_channel_id'] + payload['start_node'],
|
||||
cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big"),
|
||||
@@ -248,11 +255,11 @@ class ChannelDB(SqlDB):
|
||||
self.ca_verifier = LNChannelVerifier(network, self)
|
||||
# initialized in load_data
|
||||
self._channels = {} # type: Dict[bytes, ChannelInfo]
|
||||
self._policies = {}
|
||||
self._policies = {} # type: Dict[Tuple[bytes, bytes], Policy] # (node_id, scid) -> Policy
|
||||
self._nodes = {}
|
||||
# node_id -> (host, port, ts)
|
||||
self._addresses = defaultdict(set) # type: Dict[bytes, Set[Tuple[str, int, int]]]
|
||||
self._channels_for_node = defaultdict(set)
|
||||
self._channels_for_node = defaultdict(set) # type: Dict[bytes, Set[ShortChannelID]]
|
||||
self.data_loaded = asyncio.Event()
|
||||
self.network = network # only for callback
|
||||
|
||||
@@ -495,17 +502,6 @@ class ChannelDB(SqlDB):
|
||||
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
|
||||
self.update_counts()
|
||||
|
||||
def get_routing_policy_for_channel(self, start_node_id: bytes,
|
||||
short_channel_id: bytes) -> Optional[Policy]:
|
||||
if not start_node_id or not short_channel_id: return None
|
||||
channel_info = self.get_channel_info(short_channel_id)
|
||||
if channel_info is not None:
|
||||
return self.get_policy_for_node(short_channel_id, start_node_id)
|
||||
msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
|
||||
if not msg:
|
||||
return None
|
||||
return Policy.from_msg(msg) # won't actually be written to DB
|
||||
|
||||
def get_old_policies(self, delta):
|
||||
now = int(time.time())
|
||||
return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta)
|
||||
@@ -587,12 +583,56 @@ class ChannelDB(SqlDB):
|
||||
out.add(short_channel_id)
|
||||
self.logger.info(f'semi-orphaned: {len(out)}')
|
||||
|
||||
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']:
|
||||
return self._policies.get((node_id, short_channel_id))
|
||||
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes, *,
|
||||
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional['Policy']:
|
||||
channel_info = self.get_channel_info(short_channel_id)
|
||||
if channel_info is not None: # publicly announced channel
|
||||
policy = self._policies.get((node_id, short_channel_id))
|
||||
if policy:
|
||||
return policy
|
||||
else: # private channel
|
||||
chan_upd_dict = self._channel_updates_for_private_channels.get((node_id, short_channel_id))
|
||||
if chan_upd_dict:
|
||||
return Policy.from_msg(chan_upd_dict)
|
||||
# check if it's one of our own channels
|
||||
if not my_channels:
|
||||
return
|
||||
chan = my_channels.get(short_channel_id) # type: Optional[Channel]
|
||||
if not chan:
|
||||
return
|
||||
if node_id == chan.node_id: # incoming direction (to us)
|
||||
remote_update_raw = chan.get_remote_update()
|
||||
if not remote_update_raw:
|
||||
return
|
||||
now = int(time.time())
|
||||
remote_update_decoded = decode_msg(remote_update_raw)[1]
|
||||
remote_update_decoded['timestamp'] = now.to_bytes(4, byteorder="big")
|
||||
remote_update_decoded['start_node'] = node_id
|
||||
return Policy.from_msg(remote_update_decoded)
|
||||
elif node_id == chan.get_local_pubkey(): # outgoing direction (from us)
|
||||
local_update_decoded = decode_msg(chan.get_outgoing_gossip_channel_update())[1]
|
||||
local_update_decoded['start_node'] = node_id
|
||||
return Policy.from_msg(local_update_decoded)
|
||||
|
||||
def get_channel_info(self, channel_id: bytes) -> ChannelInfo:
|
||||
return self._channels.get(channel_id)
|
||||
def get_channel_info(self, short_channel_id: bytes, *,
|
||||
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[ChannelInfo]:
|
||||
ret = self._channels.get(short_channel_id)
|
||||
if ret:
|
||||
return ret
|
||||
# check if it's one of our own channels
|
||||
if not my_channels:
|
||||
return
|
||||
chan = my_channels.get(short_channel_id) # type: Optional[Channel]
|
||||
ci = ChannelInfo.from_raw_msg(chan.construct_channel_announcement_without_sigs())
|
||||
return ci._replace(capacity_sat=chan.constraints.capacity)
|
||||
|
||||
def get_channels_for_node(self, node_id) -> Set[bytes]:
|
||||
"""Returns the set of channels that have node_id as one of the endpoints."""
|
||||
return self._channels_for_node.get(node_id) or set()
|
||||
def get_channels_for_node(self, node_id: bytes, *,
|
||||
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Set[bytes]:
|
||||
"""Returns the set of short channel IDs where node_id is one of the channel participants."""
|
||||
relevant_channels = self._channels_for_node.get(node_id) or set()
|
||||
relevant_channels = set(relevant_channels) # copy
|
||||
# add our own channels # TODO maybe slow?
|
||||
for chan in (my_channels.values() or []):
|
||||
if node_id in (chan.node_id, chan.get_local_pubkey()):
|
||||
relevant_channels.add(chan.short_channel_id)
|
||||
return relevant_channels
|
||||
|
||||
Reference in New Issue
Block a user