ChannelDB: avoid duplicate (host,port) entries in ChannelDB._addresses
before: node_id -> set of (host, port, ts) after: node_id -> NetAddress -> timestamp Look at e.g. add_recent_peer; we only want to store the last connection time, not all of them.
This commit is contained in:
@@ -34,6 +34,7 @@ import asyncio
|
||||
import threading
|
||||
from enum import IntEnum
|
||||
|
||||
from aiorpcx import NetAddress
|
||||
|
||||
from .sql_db import SqlDB, sql
|
||||
from . import constants, util
|
||||
@@ -53,14 +54,6 @@ FLAG_DISABLE = 1 << 1
|
||||
FLAG_DIRECTION = 1 << 0
|
||||
|
||||
|
||||
class NodeAddress(NamedTuple):
|
||||
"""Holds address information of Lightning nodes
|
||||
and how up to date this info is."""
|
||||
host: str
|
||||
port: int
|
||||
timestamp: int
|
||||
|
||||
|
||||
class ChannelInfo(NamedTuple):
|
||||
short_channel_id: ShortChannelID
|
||||
node1_id: bytes
|
||||
@@ -295,8 +288,8 @@ class ChannelDB(SqlDB):
|
||||
self._channels = {} # type: Dict[ShortChannelID, ChannelInfo]
|
||||
self._policies = {} # type: Dict[Tuple[bytes, ShortChannelID], Policy] # (node_id, scid) -> Policy
|
||||
self._nodes = {} # type: Dict[bytes, NodeInfo] # node_id -> NodeInfo
|
||||
# node_id -> (host, port, ts)
|
||||
self._addresses = defaultdict(set) # type: Dict[bytes, Set[NodeAddress]]
|
||||
# node_id -> NetAddress -> timestamp
|
||||
self._addresses = defaultdict(dict) # type: Dict[bytes, Dict[NetAddress, int]]
|
||||
self._channels_for_node = defaultdict(set) # type: Dict[bytes, Set[ShortChannelID]]
|
||||
self._recent_peers = [] # type: List[bytes] # list of node_ids
|
||||
self._chans_with_0_policies = set() # type: Set[ShortChannelID]
|
||||
@@ -321,7 +314,7 @@ class ChannelDB(SqlDB):
|
||||
now = int(time.time())
|
||||
node_id = peer.pubkey
|
||||
with self.lock:
|
||||
self._addresses[node_id].add(NodeAddress(peer.host, peer.port, now))
|
||||
self._addresses[node_id][peer.net_addr()] = now
|
||||
# list is ordered
|
||||
if node_id in self._recent_peers:
|
||||
self._recent_peers.remove(node_id)
|
||||
@@ -336,12 +329,12 @@ class ChannelDB(SqlDB):
|
||||
|
||||
def get_last_good_address(self, node_id: bytes) -> Optional[LNPeerAddr]:
|
||||
"""Returns latest address we successfully connected to, for given node."""
|
||||
r = self._addresses.get(node_id)
|
||||
if not r:
|
||||
addr_to_ts = self._addresses.get(node_id)
|
||||
if not addr_to_ts:
|
||||
return None
|
||||
addr = sorted(list(r), key=lambda x: x.timestamp, reverse=True)[0]
|
||||
addr = sorted(list(addr_to_ts), key=lambda a: addr_to_ts[a], reverse=True)[0]
|
||||
try:
|
||||
return LNPeerAddr(addr.host, addr.port, node_id)
|
||||
return LNPeerAddr(str(addr.host), addr.port, node_id)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@@ -583,7 +576,8 @@ class ChannelDB(SqlDB):
|
||||
self._db_save_node_info(node_id, msg_payload['raw'])
|
||||
with self.lock:
|
||||
for addr in node_addresses:
|
||||
self._addresses[node_id].add(NodeAddress(addr.host, addr.port, 0))
|
||||
net_addr = NetAddress(addr.host, addr.port)
|
||||
self._addresses[node_id][net_addr] = self._addresses[node_id].get(net_addr) or 0
|
||||
self._db_save_node_addresses(node_addresses)
|
||||
|
||||
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
|
||||
@@ -634,8 +628,13 @@ class ChannelDB(SqlDB):
|
||||
# delete from database
|
||||
self._db_delete_channel(short_channel_id)
|
||||
|
||||
def get_node_addresses(self, node_id):
|
||||
return self._addresses.get(node_id)
|
||||
def get_node_addresses(self, node_id: bytes) -> Sequence[Tuple[str, int, int]]:
|
||||
"""Returns list of (host, port, timestamp)."""
|
||||
addr_to_ts = self._addresses.get(node_id)
|
||||
if not addr_to_ts:
|
||||
return []
|
||||
return [(str(net_addr.host), net_addr.port, ts)
|
||||
for net_addr, ts in addr_to_ts.items()]
|
||||
|
||||
@sql
|
||||
@profiler
|
||||
@@ -643,17 +642,19 @@ class ChannelDB(SqlDB):
|
||||
if self.data_loaded.is_set():
|
||||
return
|
||||
# Note: this method takes several seconds... mostly due to lnmsg.decode_msg being slow.
|
||||
# I believe lnmsg (and lightning.json) will need a rewrite anyway, so instead of tweaking
|
||||
# load_data() here, that should be done. see #6006
|
||||
c = self.conn.cursor()
|
||||
c.execute("""SELECT * FROM address""")
|
||||
for x in c:
|
||||
node_id, host, port, timestamp = x
|
||||
self._addresses[node_id].add(NodeAddress(str(host), int(port), int(timestamp or 0)))
|
||||
try:
|
||||
net_addr = NetAddress(host, port)
|
||||
except Exception:
|
||||
continue
|
||||
self._addresses[node_id][net_addr] = int(timestamp or 0)
|
||||
def newest_ts_for_node_id(node_id):
|
||||
newest_ts = 0
|
||||
for addr in self._addresses[node_id]:
|
||||
newest_ts = max(newest_ts, addr.timestamp)
|
||||
for addr, ts in self._addresses[node_id].items():
|
||||
newest_ts = max(newest_ts, ts)
|
||||
return newest_ts
|
||||
sorted_node_ids = sorted(self._addresses.keys(), key=newest_ts_for_node_id, reverse=True)
|
||||
self._recent_peers = sorted_node_ids[:self.NUM_MAX_RECENT_PEERS]
|
||||
@@ -791,7 +792,10 @@ class ChannelDB(SqlDB):
|
||||
graph['nodes'].append(
|
||||
nodeinfo._asdict(),
|
||||
)
|
||||
graph['nodes'][-1]['addresses'] = [addr._asdict() for addr in self._addresses[pk]]
|
||||
graph['nodes'][-1]['addresses'] = [
|
||||
{'host': str(addr.host), 'port': addr.port, 'timestamp': ts}
|
||||
for addr, ts in self._addresses[pk].items()
|
||||
]
|
||||
|
||||
# gather channels
|
||||
for cid, channelinfo in self._channels.items():
|
||||
|
||||
@@ -1106,6 +1106,7 @@ def derive_payment_secret_from_payment_preimage(payment_preimage: bytes) -> byte
|
||||
|
||||
|
||||
class LNPeerAddr:
|
||||
# note: while not programmatically enforced, this class is meant to be *immutable*
|
||||
|
||||
def __init__(self, host: str, port: int, pubkey: bytes):
|
||||
assert isinstance(host, str), repr(host)
|
||||
@@ -1120,7 +1121,7 @@ class LNPeerAddr:
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.pubkey = pubkey
|
||||
self._net_addr_str = str(net_addr)
|
||||
self._net_addr = net_addr
|
||||
|
||||
def __str__(self):
|
||||
return '{}@{}'.format(self.pubkey.hex(), self.net_addr_str())
|
||||
@@ -1128,8 +1129,11 @@ class LNPeerAddr:
|
||||
def __repr__(self):
|
||||
return f'<LNPeerAddr host={self.host} port={self.port} pubkey={self.pubkey.hex()}>'
|
||||
|
||||
def net_addr(self) -> NetAddress:
|
||||
return self._net_addr
|
||||
|
||||
def net_addr_str(self) -> str:
|
||||
return self._net_addr_str
|
||||
return str(self._net_addr)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, LNPeerAddr):
|
||||
|
||||
Reference in New Issue
Block a user