ChannelDB: avoid duplicate (host,port) entries in ChannelDB._addresses
before: node_id -> set of (host, port, ts) after: node_id -> NetAddress -> timestamp Look at e.g. add_recent_peer; we only want to store the last connection time, not all of them.
This commit is contained in:
@@ -34,6 +34,7 @@ import asyncio
|
|||||||
import threading
|
import threading
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
|
|
||||||
|
from aiorpcx import NetAddress
|
||||||
|
|
||||||
from .sql_db import SqlDB, sql
|
from .sql_db import SqlDB, sql
|
||||||
from . import constants, util
|
from . import constants, util
|
||||||
@@ -53,14 +54,6 @@ FLAG_DISABLE = 1 << 1
|
|||||||
FLAG_DIRECTION = 1 << 0
|
FLAG_DIRECTION = 1 << 0
|
||||||
|
|
||||||
|
|
||||||
class NodeAddress(NamedTuple):
|
|
||||||
"""Holds address information of Lightning nodes
|
|
||||||
and how up to date this info is."""
|
|
||||||
host: str
|
|
||||||
port: int
|
|
||||||
timestamp: int
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelInfo(NamedTuple):
|
class ChannelInfo(NamedTuple):
|
||||||
short_channel_id: ShortChannelID
|
short_channel_id: ShortChannelID
|
||||||
node1_id: bytes
|
node1_id: bytes
|
||||||
@@ -295,8 +288,8 @@ class ChannelDB(SqlDB):
|
|||||||
self._channels = {} # type: Dict[ShortChannelID, ChannelInfo]
|
self._channels = {} # type: Dict[ShortChannelID, ChannelInfo]
|
||||||
self._policies = {} # type: Dict[Tuple[bytes, ShortChannelID], Policy] # (node_id, scid) -> Policy
|
self._policies = {} # type: Dict[Tuple[bytes, ShortChannelID], Policy] # (node_id, scid) -> Policy
|
||||||
self._nodes = {} # type: Dict[bytes, NodeInfo] # node_id -> NodeInfo
|
self._nodes = {} # type: Dict[bytes, NodeInfo] # node_id -> NodeInfo
|
||||||
# node_id -> (host, port, ts)
|
# node_id -> NetAddress -> timestamp
|
||||||
self._addresses = defaultdict(set) # type: Dict[bytes, Set[NodeAddress]]
|
self._addresses = defaultdict(dict) # type: Dict[bytes, Dict[NetAddress, int]]
|
||||||
self._channels_for_node = defaultdict(set) # type: Dict[bytes, Set[ShortChannelID]]
|
self._channels_for_node = defaultdict(set) # type: Dict[bytes, Set[ShortChannelID]]
|
||||||
self._recent_peers = [] # type: List[bytes] # list of node_ids
|
self._recent_peers = [] # type: List[bytes] # list of node_ids
|
||||||
self._chans_with_0_policies = set() # type: Set[ShortChannelID]
|
self._chans_with_0_policies = set() # type: Set[ShortChannelID]
|
||||||
@@ -321,7 +314,7 @@ class ChannelDB(SqlDB):
|
|||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
node_id = peer.pubkey
|
node_id = peer.pubkey
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self._addresses[node_id].add(NodeAddress(peer.host, peer.port, now))
|
self._addresses[node_id][peer.net_addr()] = now
|
||||||
# list is ordered
|
# list is ordered
|
||||||
if node_id in self._recent_peers:
|
if node_id in self._recent_peers:
|
||||||
self._recent_peers.remove(node_id)
|
self._recent_peers.remove(node_id)
|
||||||
@@ -336,12 +329,12 @@ class ChannelDB(SqlDB):
|
|||||||
|
|
||||||
def get_last_good_address(self, node_id: bytes) -> Optional[LNPeerAddr]:
|
def get_last_good_address(self, node_id: bytes) -> Optional[LNPeerAddr]:
|
||||||
"""Returns latest address we successfully connected to, for given node."""
|
"""Returns latest address we successfully connected to, for given node."""
|
||||||
r = self._addresses.get(node_id)
|
addr_to_ts = self._addresses.get(node_id)
|
||||||
if not r:
|
if not addr_to_ts:
|
||||||
return None
|
return None
|
||||||
addr = sorted(list(r), key=lambda x: x.timestamp, reverse=True)[0]
|
addr = sorted(list(addr_to_ts), key=lambda a: addr_to_ts[a], reverse=True)[0]
|
||||||
try:
|
try:
|
||||||
return LNPeerAddr(addr.host, addr.port, node_id)
|
return LNPeerAddr(str(addr.host), addr.port, node_id)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -583,7 +576,8 @@ class ChannelDB(SqlDB):
|
|||||||
self._db_save_node_info(node_id, msg_payload['raw'])
|
self._db_save_node_info(node_id, msg_payload['raw'])
|
||||||
with self.lock:
|
with self.lock:
|
||||||
for addr in node_addresses:
|
for addr in node_addresses:
|
||||||
self._addresses[node_id].add(NodeAddress(addr.host, addr.port, 0))
|
net_addr = NetAddress(addr.host, addr.port)
|
||||||
|
self._addresses[node_id][net_addr] = self._addresses[node_id].get(net_addr) or 0
|
||||||
self._db_save_node_addresses(node_addresses)
|
self._db_save_node_addresses(node_addresses)
|
||||||
|
|
||||||
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
|
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
|
||||||
@@ -634,8 +628,13 @@ class ChannelDB(SqlDB):
|
|||||||
# delete from database
|
# delete from database
|
||||||
self._db_delete_channel(short_channel_id)
|
self._db_delete_channel(short_channel_id)
|
||||||
|
|
||||||
def get_node_addresses(self, node_id):
|
def get_node_addresses(self, node_id: bytes) -> Sequence[Tuple[str, int, int]]:
|
||||||
return self._addresses.get(node_id)
|
"""Returns list of (host, port, timestamp)."""
|
||||||
|
addr_to_ts = self._addresses.get(node_id)
|
||||||
|
if not addr_to_ts:
|
||||||
|
return []
|
||||||
|
return [(str(net_addr.host), net_addr.port, ts)
|
||||||
|
for net_addr, ts in addr_to_ts.items()]
|
||||||
|
|
||||||
@sql
|
@sql
|
||||||
@profiler
|
@profiler
|
||||||
@@ -643,17 +642,19 @@ class ChannelDB(SqlDB):
|
|||||||
if self.data_loaded.is_set():
|
if self.data_loaded.is_set():
|
||||||
return
|
return
|
||||||
# Note: this method takes several seconds... mostly due to lnmsg.decode_msg being slow.
|
# Note: this method takes several seconds... mostly due to lnmsg.decode_msg being slow.
|
||||||
# I believe lnmsg (and lightning.json) will need a rewrite anyway, so instead of tweaking
|
|
||||||
# load_data() here, that should be done. see #6006
|
|
||||||
c = self.conn.cursor()
|
c = self.conn.cursor()
|
||||||
c.execute("""SELECT * FROM address""")
|
c.execute("""SELECT * FROM address""")
|
||||||
for x in c:
|
for x in c:
|
||||||
node_id, host, port, timestamp = x
|
node_id, host, port, timestamp = x
|
||||||
self._addresses[node_id].add(NodeAddress(str(host), int(port), int(timestamp or 0)))
|
try:
|
||||||
|
net_addr = NetAddress(host, port)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
self._addresses[node_id][net_addr] = int(timestamp or 0)
|
||||||
def newest_ts_for_node_id(node_id):
|
def newest_ts_for_node_id(node_id):
|
||||||
newest_ts = 0
|
newest_ts = 0
|
||||||
for addr in self._addresses[node_id]:
|
for addr, ts in self._addresses[node_id].items():
|
||||||
newest_ts = max(newest_ts, addr.timestamp)
|
newest_ts = max(newest_ts, ts)
|
||||||
return newest_ts
|
return newest_ts
|
||||||
sorted_node_ids = sorted(self._addresses.keys(), key=newest_ts_for_node_id, reverse=True)
|
sorted_node_ids = sorted(self._addresses.keys(), key=newest_ts_for_node_id, reverse=True)
|
||||||
self._recent_peers = sorted_node_ids[:self.NUM_MAX_RECENT_PEERS]
|
self._recent_peers = sorted_node_ids[:self.NUM_MAX_RECENT_PEERS]
|
||||||
@@ -791,7 +792,10 @@ class ChannelDB(SqlDB):
|
|||||||
graph['nodes'].append(
|
graph['nodes'].append(
|
||||||
nodeinfo._asdict(),
|
nodeinfo._asdict(),
|
||||||
)
|
)
|
||||||
graph['nodes'][-1]['addresses'] = [addr._asdict() for addr in self._addresses[pk]]
|
graph['nodes'][-1]['addresses'] = [
|
||||||
|
{'host': str(addr.host), 'port': addr.port, 'timestamp': ts}
|
||||||
|
for addr, ts in self._addresses[pk].items()
|
||||||
|
]
|
||||||
|
|
||||||
# gather channels
|
# gather channels
|
||||||
for cid, channelinfo in self._channels.items():
|
for cid, channelinfo in self._channels.items():
|
||||||
|
|||||||
@@ -1106,6 +1106,7 @@ def derive_payment_secret_from_payment_preimage(payment_preimage: bytes) -> byte
|
|||||||
|
|
||||||
|
|
||||||
class LNPeerAddr:
|
class LNPeerAddr:
|
||||||
|
# note: while not programmatically enforced, this class is meant to be *immutable*
|
||||||
|
|
||||||
def __init__(self, host: str, port: int, pubkey: bytes):
|
def __init__(self, host: str, port: int, pubkey: bytes):
|
||||||
assert isinstance(host, str), repr(host)
|
assert isinstance(host, str), repr(host)
|
||||||
@@ -1120,7 +1121,7 @@ class LNPeerAddr:
|
|||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
self.pubkey = pubkey
|
self.pubkey = pubkey
|
||||||
self._net_addr_str = str(net_addr)
|
self._net_addr = net_addr
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return '{}@{}'.format(self.pubkey.hex(), self.net_addr_str())
|
return '{}@{}'.format(self.pubkey.hex(), self.net_addr_str())
|
||||||
@@ -1128,8 +1129,11 @@ class LNPeerAddr:
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'<LNPeerAddr host={self.host} port={self.port} pubkey={self.pubkey.hex()}>'
|
return f'<LNPeerAddr host={self.host} port={self.port} pubkey={self.pubkey.hex()}>'
|
||||||
|
|
||||||
|
def net_addr(self) -> NetAddress:
|
||||||
|
return self._net_addr
|
||||||
|
|
||||||
def net_addr_str(self) -> str:
|
def net_addr_str(self) -> str:
|
||||||
return self._net_addr_str
|
return str(self._net_addr)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if not isinstance(other, LNPeerAddr):
|
if not isinstance(other, LNPeerAddr):
|
||||||
|
|||||||
Reference in New Issue
Block a user