1
0

Merge pull request #9542 from f321x/gossip_forwarding

Introduce incoming gossip query handling and forwarding
This commit is contained in:
ThomasV
2025-03-04 14:07:44 +01:00
committed by GitHub
6 changed files with 508 additions and 38 deletions

View File

@@ -43,7 +43,8 @@ from .sql_db import SqlDB, sql
from . import constants, util
from .util import profiler, get_headers_dir, is_ip_address, json_normalize, UserFacingException, is_private_netaddress
from .lntransport import LNPeerAddr
from .lnutil import ShortChannelID, validate_features, IncompatibleOrInsaneFeatures, InvalidGossipMsg
from .lnutil import (ShortChannelID, validate_features, IncompatibleOrInsaneFeatures,
InvalidGossipMsg, GossipForwardingMessage, GossipTimestampFilter)
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
from .lnmsg import decode_msg
from .crypto import sha256d
@@ -68,6 +69,7 @@ class ChannelInfo(NamedTuple):
node1_id: bytes
node2_id: bytes
capacity_sat: Optional[int]
raw: Optional[bytes] = None
@staticmethod
def from_msg(payload: dict) -> 'ChannelInfo':
@@ -82,12 +84,14 @@ class ChannelInfo(NamedTuple):
short_channel_id = ShortChannelID.normalize(channel_id),
node1_id = node_id_1,
node2_id = node_id_2,
capacity_sat = capacity_sat
capacity_sat = capacity_sat,
raw = payload.get('raw')
)
@staticmethod
def from_raw_msg(raw: bytes) -> 'ChannelInfo':
payload_dict = decode_msg(raw)[1]
payload_dict['raw'] = raw
return ChannelInfo.from_msg(payload_dict)
@staticmethod
@@ -111,6 +115,7 @@ class Policy(NamedTuple):
channel_flags: int
message_flags: int
timestamp: int
raw: Optional[bytes] = None
@staticmethod
def from_msg(payload: dict) -> 'Policy':
@@ -124,12 +129,14 @@ class Policy(NamedTuple):
message_flags = int.from_bytes(payload['message_flags'], "big"),
channel_flags = int.from_bytes(payload['channel_flags'], "big"),
timestamp = payload['timestamp'],
raw = payload.get('raw'),
)
@staticmethod
def from_raw_msg(key:bytes, raw: bytes) -> 'Policy':
def from_raw_msg(key: bytes, raw: bytes) -> 'Policy':
payload = decode_msg(raw)[1]
payload['start_node'] = key[8:]
payload['raw'] = raw
return Policy.from_msg(payload)
@staticmethod
@@ -163,6 +170,7 @@ class NodeInfo(NamedTuple):
features: int
timestamp: int
alias: str
raw: Optional[bytes]
@staticmethod
def from_msg(payload) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]:
@@ -182,12 +190,18 @@ class NodeInfo(NamedTuple):
except Exception:
alias = ''
timestamp = payload['timestamp']
node_info = NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias)
node_info = NodeInfo(
node_id=node_id,
features=features,
timestamp=timestamp,
alias=alias,
raw=payload.get('raw'))
return node_info, peer_addrs
@staticmethod
def from_raw_msg(raw: bytes) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]:
payload_dict = decode_msg(raw)[1]
payload_dict['raw'] = raw
return NodeInfo.from_msg(payload_dict)
@staticmethod
@@ -240,6 +254,7 @@ class NodeInfo(NamedTuple):
nonlocal buf
data, buf = buf[0:n], buf[n:]
return data
addresses = []
while buf:
atype = ord(read(1))
@@ -389,6 +404,12 @@ class ChannelDB(SqlDB):
self._chans_with_1_policies = set() # type: Set[ShortChannelID]
self._chans_with_2_policies = set() # type: Set[ShortChannelID]
self.forwarding_lock = threading.RLock()
self.fwd_channels = [] # type: List[GossipForwardingMessage]
self.fwd_orphan_channels = [] # type: List[GossipForwardingMessage]
self.fwd_channel_updates = [] # type: List[GossipForwardingMessage]
self.fwd_node_announcements = [] # type: List[GossipForwardingMessage]
self.data_loaded = asyncio.Event()
self.network = network # only for callback
@@ -487,6 +508,9 @@ class ChannelDB(SqlDB):
self._update_num_policies_for_chan(channel_info.short_channel_id)
if 'raw' in msg:
self._db_save_channel(channel_info.short_channel_id, msg['raw'])
with self.forwarding_lock:
if fwd_msg := GossipForwardingMessage.from_payload(msg):
self.fwd_channels.append(fwd_msg)
def policy_changed(self, old_policy: Policy, new_policy: Policy, verbose: bool) -> bool:
changed = False
@@ -555,6 +579,10 @@ class ChannelDB(SqlDB):
if old_policy and not self.policy_changed(old_policy, policy, verbose):
return UpdateStatus.UNCHANGED
else:
if policy.message_flags & 0b10 == 0: # check if its `dont_forward`
with self.forwarding_lock:
if fwd_msg := GossipForwardingMessage.from_payload(payload):
self.fwd_channel_updates.append(fwd_msg)
return UpdateStatus.GOOD
def add_channel_updates(self, payloads, max_age=None) -> CategorizedChannelUpdates:
@@ -667,7 +695,7 @@ class ChannelDB(SqlDB):
# note: signatures have already been verified.
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
new_nodes = {}
new_nodes = set() # type: Set[bytes]
for msg_payload in msg_payloads:
try:
node_info, node_addresses = NodeInfo.from_msg(msg_payload)
@@ -681,9 +709,7 @@ class ChannelDB(SqlDB):
node = self._nodes.get(node_id)
if node and node.timestamp >= node_info.timestamp:
continue
node = new_nodes.get(node_id)
if node and node.timestamp >= node_info.timestamp:
continue
new_nodes.add(node_id)
# save
with self.lock:
self._nodes[node_id] = node_info
@@ -694,6 +720,9 @@ class ChannelDB(SqlDB):
net_addr = NetAddress(addr.host, addr.port)
self._addresses[node_id][net_addr] = self._addresses[node_id].get(net_addr) or 0
self._db_save_node_addresses(node_addresses)
with self.forwarding_lock:
if fwd_msg := GossipForwardingMessage.from_payload(msg_payload):
self.fwd_node_announcements.append(fwd_msg)
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
self.update_counts()
@@ -995,6 +1024,127 @@ class ChannelDB(SqlDB):
return k
raise Exception('node not found')
def clear_forwarding_gossip(self) -> None:
    """Drop all gossip messages currently queued for forwarding."""
    with self.forwarding_lock:
        for queued in (self.fwd_channels, self.fwd_channel_updates, self.fwd_node_announcements):
            queued.clear()
def filter_orphan_channel_anns(
    self, channel_anns: List[GossipForwardingMessage]
) -> Tuple[List, List]:
    """Split channel announcements into forwardable ones (at least one known
    channel policy) and orphans (no policy known yet).

    Announcements without a scid are dropped entirely.
    """
    with_policy = []
    orphans = []
    for ann in channel_anns:
        scid = ann.scid
        if scid is None:
            continue
        if scid in self._chans_with_1_policies or scid in self._chans_with_2_policies:
            with_policy.append(ann)
        else:
            orphans.append(ann)
    return with_policy, orphans
def set_fwd_channel_anns_ts(self, channel_anns: List[GossipForwardingMessage]) \
        -> List[GossipForwardingMessage]:
    """Set the timestamps of the passed channel announcements from the corresponding policies.

    channel_announcement messages carry no timestamp of their own, so we use the
    oldest timestamp of the channel's known policies.  Announcements for unknown
    channels, or channels without any policy, are dropped.
    """
    # snapshot under lock so we can iterate without holding it
    with self.lock:
        policies = self._policies.copy()
        channels = self._channels.copy()
    timestamped = []  # type: List[GossipForwardingMessage]
    for ann in channel_anns:
        if ann.timestamp is not None:
            timestamped.append(ann)
            continue
        channel_info = channels.get(ann.scid)
        if channel_info is None:
            continue
        candidates = [
            policy.timestamp
            for node_id in (channel_info.node1_id, channel_info.node2_id)
            if (policy := policies.get((node_id, ann.scid))) is not None
        ]
        if not candidates:
            continue
        ann.timestamp = min(candidates)
        timestamped.append(ann)
    return timestamped
def get_forwarding_gossip_batch(self) -> List[GossipForwardingMessage]:
    """Return the gossip accumulated since the last call and reset the queues.

    Channel announcements that still have no known policy are held back as
    orphans and re-checked on the next batch instead of being forwarded.
    """
    with self.forwarding_lock:
        other_gossip = self.fwd_channel_updates + self.fwd_node_announcements
        new_chan_anns = list(self.fwd_channels)
        self.clear_forwarding_gossip()
    # previously orphaned announcements may have received a policy by now
    ready_old, _ = self.filter_orphan_channel_anns(self.fwd_orphan_channels)
    ready_new, self.fwd_orphan_channels = self.filter_orphan_channel_anns(new_chan_anns)
    timestamped_anns = self.set_fwd_channel_anns_ts(ready_old + ready_new)
    return timestamped_anns + other_gossip
def get_gossip_in_timespan(self, timespan: GossipTimestampFilter) \
        -> List[GossipForwardingMessage]:
    """Return a list of gossip messages matching the requested timespan.

    Used to serve a peer's gossip_timestamp_filter request for historical
    gossip.  A channel announcement carries no timestamp of its own, so it is
    included only if at least one of its channel_updates falls within the
    filter range, and it inherits the oldest matching update's timestamp.
    Updates flagged dont_forward are excluded.
    """
    forwarding_gossip = []
    # snapshot under lock so we can iterate without holding it
    with self.lock:
        chans = self._channels.copy()
        policies = self._policies.copy()
        nodes = self._nodes.copy()
    for short_id, chan in chans.items():
        # fetching the timestamp from the channel update (according to BOLT-07)
        chan_up_n1 = policies.get((chan.node1_id, short_id))
        chan_up_n2 = policies.get((chan.node2_id, short_id))
        updates = []
        for policy in [chan_up_n1, chan_up_n2]:
            if policy and policy.raw and timespan.in_range(policy.timestamp):
                if policy.message_flags & 0b10 == 0:  # check that its not "dont_forward"
                    updates.append(GossipForwardingMessage(
                        msg=policy.raw,
                        timestamp=policy.timestamp))
        if not updates or chan.raw is None:
            continue
        # announcement must precede its updates on the wire
        chan_ann_ts = min(update.timestamp for update in updates)
        channel_announcement = GossipForwardingMessage(msg=chan.raw, timestamp=chan_ann_ts)
        forwarding_gossip.extend([channel_announcement] + updates)
    for node_ann in nodes.values():
        if timespan.in_range(node_ann.timestamp) and node_ann.raw:
            forwarding_gossip.append(GossipForwardingMessage(
                msg=node_ann.raw,
                timestamp=node_ann.timestamp))
    return forwarding_gossip
def get_channels_in_range(self, first_blocknum: int, number_of_blocks: int) -> List[ShortChannelID]:
    """Return the sorted scids whose funding block height lies in
    [first_blocknum, first_blocknum + number_of_blocks)."""
    with self.lock:
        known_scids = list(self._channels)
    upper_bound = first_blocknum + number_of_blocks
    return sorted(
        scid for scid in known_scids
        if first_blocknum <= scid.block_height < upper_bound
    )
def get_gossip_for_scid_request(self, scid: ShortChannelID) -> List[bytes]:
    """Collect the raw gossip describing *scid*: its channel announcement,
    both channel updates and both node announcements, where known.

    Returns an empty list if the channel (or its raw message) is unknown.
    """
    channel = self._channels.get(scid)
    if channel is None or not channel.raw:
        return []
    related = [
        channel,
        self._policies.get((channel.node1_id, scid)),
        self._policies.get((channel.node2_id, scid)),
        self._nodes.get(channel.node1_id),
        self._nodes.get(channel.node2_id),
    ]
    return [item.raw for item in related if item and item.raw]
def to_dict(self) -> dict:
""" Generates a graph representation in terms of a dictionary.

View File

@@ -9,7 +9,7 @@ from collections import OrderedDict, defaultdict
import asyncio
import os
import time
from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set, Callable, Awaitable
from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set, Callable, Awaitable, List
from datetime import datetime
import functools
@@ -45,7 +45,8 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, ChannelConf
RemoteMisbehaving, ShortChannelID,
IncompatibleLightningFeatures, derive_payment_secret_from_payment_preimage,
ChannelType, LNProtocolWarning, validate_features,
IncompatibleOrInsaneFeatures, FeeBudgetExceeded)
IncompatibleOrInsaneFeatures, FeeBudgetExceeded,
GossipForwardingMessage, GossipTimestampFilter)
from .lnutil import FeeUpdate, channel_id_from_funding_tx, PaymentFeeBudget
from .lnutil import serialize_htlc_key, Keypair
from .lntransport import LNTransport, LNTransportBase, LightningPeerConnectionClosed, HandshakeFailed
@@ -74,7 +75,9 @@ class Peer(Logger, EventListener):
ORDERED_MESSAGES = (
'accept_channel', 'funding_signed', 'funding_created', 'accept_channel', 'closing_signed')
SPAMMY_MESSAGES = (
'ping', 'pong', 'channel_announcement', 'node_announcement', 'channel_update',)
'ping', 'pong', 'channel_announcement', 'node_announcement', 'channel_update',
'gossip_timestamp_filter', 'reply_channel_range', 'query_channel_range',
'query_short_channel_ids', 'reply_short_channel_ids', 'reply_short_channel_ids_end')
DELAY_INC_MSG_PROCESSING_SLEEP = 0.01
@@ -106,6 +109,8 @@ class Peer(Logger, EventListener):
self.reply_channel_range = asyncio.Queue()
# gossip uses a single queue to preserve message order
self.gossip_queue = asyncio.Queue()
self.gossip_timestamp_filter = None # type: Optional[GossipTimestampFilter]
self.outgoing_gossip_reply = False # type: bool
self.ordered_message_queues = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue] # for messages that are ordered
self.temp_id_to_id = {} # type: Dict[bytes, Optional[bytes]] # to forward error messages
self.funding_created_sent = set() # for channels in PREOPENING
@@ -168,8 +173,7 @@ class Peer(Logger, EventListener):
await self.transport.handshake()
self.logger.info(f"handshake done for {self.transport.peer_addr or self.pubkey.hex()}")
features = self.features.for_init_message()
b = int.bit_length(features)
flen = b // 8 + int(bool(b % 8))
flen = features.min_len()
self.send_message(
"init", gflen=0, flen=flen,
features=features,
@@ -240,6 +244,7 @@ class Peer(Logger, EventListener):
# raw message is needed to check signature
if message_type in ['node_announcement', 'channel_announcement', 'channel_update']:
payload['raw'] = message
payload['sender_node_id'] = self.pubkey
# note: the message handler might be async or non-async. In either case, by default,
# we wait for it to complete before we return, i.e. before the next message is processed.
if asyncio.iscoroutinefunction(f):
@@ -294,7 +299,7 @@ class Peer(Logger, EventListener):
return
raise GracefulDisconnect
async def send_warning(self, channel_id: bytes, message: str = None, *, close_connection=False):
def send_warning(self, channel_id: bytes, message: str = None, *, close_connection=False):
"""Sends a warning and disconnects if close_connection.
Note:
@@ -313,7 +318,7 @@ class Peer(Logger, EventListener):
if close_connection:
raise GracefulDisconnect
async def send_error(self, channel_id: bytes, message: str = None, *, force_close_channel=False):
def send_error(self, channel_id: bytes, message: str = None, *, force_close_channel=False):
"""Sends an error message and force closes the channel.
Note:
@@ -406,6 +411,42 @@ class Peer(Logger, EventListener):
if not self.lnworker.uses_trampoline():
self.gossip_queue.put_nowait(('channel_update', payload))
def on_query_channel_range(self, payload):
    """Handle an incoming query_channel_range by scheduling a reply on the gossip queue.

    Ignored when we are the dedicated gossip-fetching node (lngossip) or when we
    do not serve gossip to this peer.  Only one outgoing reply is served per peer
    at a time (guarded by outgoing_gossip_reply).
    """
    if self.lnworker == self.lnworker.network.lngossip or not self._should_forward_gossip():
        return
    if not self._is_valid_channel_range_query(payload):
        return self.send_warning(bytes(32), "received invalid query_channel_range")
    if self.outgoing_gossip_reply:
        return self.send_warning(bytes(32), "received multiple queries at the same time")
    self.outgoing_gossip_reply = True  # cleared again once the reply has been sent
    self.gossip_queue.put_nowait(('query_channel_range', payload))
def on_query_short_channel_ids(self, payload):
    """Handle an incoming query_short_channel_ids by scheduling a reply on the gossip queue.

    Ignored when we are the dedicated gossip-fetching node (lngossip) or when we
    do not serve gossip to this peer.  Only one outgoing reply is served per peer
    at a time (guarded by outgoing_gossip_reply).
    """
    if self.lnworker == self.lnworker.network.lngossip or not self._should_forward_gossip():
        return
    if self.outgoing_gossip_reply:
        return self.send_warning(bytes(32), "received multiple queries at the same time")
    if not self._is_valid_short_channel_id_query(payload):
        return self.send_warning(bytes(32), "invalid query_short_channel_ids")
    self.outgoing_gossip_reply = True  # cleared again once the reply has been sent
    self.gossip_queue.put_nowait(('query_short_channel_ids', payload))
def on_gossip_timestamp_filter(self, payload):
    """Handle an incoming gossip_timestamp_filter message from the peer."""
    if self._should_forward_gossip():
        self.set_gossip_timestamp_filter(payload)
def set_gossip_timestamp_filter(self, payload: dict) -> None:
    """Set the gossip_timestamp_filter for this peer. If the peer requested historical gossip,
    the request is put on the queue, otherwise only the forwarding loop will check the filter"""
    if payload.get('chain_hash') != constants.net.rev_genesis_bytes():
        return
    # local renamed from `filter` to avoid shadowing the builtin
    ts_filter = GossipTimestampFilter.from_payload(payload)  # may be None for invalid payloads
    self.gossip_timestamp_filter = ts_filter
    self.logger.debug(f"got gossip_ts_filter from peer {self.pubkey.hex()}: "
                      f"{str(self.gossip_timestamp_filter)}")
    if ts_filter and not ts_filter.only_forwarding:
        # peer asked for historical gossip too; serve it via the gossip queue
        self.gossip_queue.put_nowait(('gossip_timestamp_filter', None))
def maybe_save_remote_update(self, payload):
if not self.channels:
return
@@ -424,6 +465,8 @@ class Peer(Logger, EventListener):
# This code assumes gossip_queries is set. BOLT7: "if the
# gossip_queries feature is negotiated, [a node] MUST NOT
# send gossip it did not generate itself"
# NOTE: The definition of gossip_queries changed
# https://github.com/lightning/bolts/commit/fce8bab931674a81a9ea895c9e9162e559e48a65
short_channel_id = ShortChannelID(payload['short_channel_id'])
self.logger.info(f'received orphan channel update {short_channel_id}')
self.orphan_channel_updates[short_channel_id] = payload
@@ -462,13 +505,14 @@ class Peer(Logger, EventListener):
@handle_disconnect
async def main_loop(self):
async with self.taskgroup as group:
await group.spawn(self._message_loop())
await group.spawn(self.htlc_switch())
await group.spawn(self.query_gossip())
await group.spawn(self.process_gossip())
await group.spawn(self.send_own_gossip())
await group.spawn(self._message_loop())
await group.spawn(self._query_gossip())
await group.spawn(self._process_gossip())
await group.spawn(self._send_own_gossip())
await group.spawn(self._forward_gossip())
async def process_gossip(self):
async def _process_gossip(self):
while True:
await asyncio.sleep(5)
if not self.network.lngossip:
@@ -484,6 +528,12 @@ class Peer(Logger, EventListener):
chan_upds.append(payload)
elif name == 'node_announcement':
node_anns.append(payload)
elif name == 'query_channel_range':
await self.taskgroup.spawn(self._send_reply_channel_range(payload))
elif name == 'query_short_channel_ids':
await self.taskgroup.spawn(self._send_reply_short_channel_ids(payload))
elif name == 'gossip_timestamp_filter':
await self.taskgroup.spawn(self._handle_historical_gossip_request())
else:
raise Exception('unknown message')
if self.gossip_queue.empty():
@@ -491,7 +541,7 @@ class Peer(Logger, EventListener):
if self.network.lngossip:
await self.network.lngossip.process_gossip(chan_anns, node_anns, chan_upds)
async def send_own_gossip(self):
async def _send_own_gossip(self):
if self.lnworker == self.lnworker.network.lngossip:
return
await asyncio.sleep(10)
@@ -505,18 +555,76 @@ class Peer(Logger, EventListener):
self.maybe_send_channel_announcement(chan)
await asyncio.sleep(600)
async def query_gossip(self):
def _should_forward_gossip(self) -> bool:
if (self.network.lngossip != self.lnworker
and not self.lnworker.uses_trampoline()
and self.features.supports(LnFeatures.GOSSIP_QUERIES_REQ)):
return True
return False
async def _forward_gossip(self):
    """Periodically forward freshly curated gossip batches to this peer.

    Runs for the lifetime of the connection; does nothing until the peer has
    installed a gossip_timestamp_filter.  Batches come from the shared lngossip
    curator; sending is bounded by lngossip's gossip_request_semaphore.
    """
    if not self._should_forward_gossip():
        return

    async def send_new_gossip_with_semaphore(gossip: List[GossipForwardingMessage]):
        # limit the number of concurrent gossip-serving tasks across peers
        async with self.network.lngossip.gossip_request_semaphore:
            sent = await self._send_gossip_messages(gossip)
            if sent > 0:
                self.logger.debug(f"forwarded {sent} gossip messages to {self.pubkey.hex()}")

    lngossip = self.network.lngossip
    last_gossip_batch_ts = 0
    while True:
        await asyncio.sleep(10)
        if not self.gossip_timestamp_filter:
            continue  # peer didn't request gossip
        new_gossip, last_lngossip_refresh_ts = await lngossip.get_forwarding_gossip()
        if not last_lngossip_refresh_ts > last_gossip_batch_ts:
            continue  # no new batch available
        last_gossip_batch_ts = last_lngossip_refresh_ts
        # send in background so a slow peer doesn't stall this loop
        await self.taskgroup.spawn(send_new_gossip_with_semaphore(new_gossip))
async def _handle_historical_gossip_request(self):
    """Called when a peer requests historical gossip with a gossip_timestamp_filter query."""
    filter = self.gossip_timestamp_filter
    if not self._should_forward_gossip() or not filter or filter.only_forwarding:
        return
    async with self.network.lngossip.gossip_request_semaphore:
        requested_gossip = self.lnworker.channel_db.get_gossip_in_timespan(filter)
        # historical part is served below; from now on only forward fresh gossip
        filter.only_forwarding = True
        sent = await self._send_gossip_messages(requested_gossip)
        if sent > 0:
            self.logger.debug(f"forwarded {sent} historical gossip messages to {self.pubkey.hex()}")
async def _send_gossip_messages(self, messages: List[GossipForwardingMessage]) -> int:
    """Send the given raw gossip messages to the peer and return how many were sent.

    Messages outside the peer's timestamp filter, and messages that originated
    from this peer itself, are skipped.
    """
    amount_sent = 0
    for msg in messages:
        if self.gossip_timestamp_filter.in_range(msg.timestamp) \
                and self.pubkey != msg.sender_node_id:
            await self.transport.send_bytes_and_drain(msg.msg)
            amount_sent += 1
            if amount_sent % 250 == 0:
                # this can be a lot of messages, completely blocking the event loop
                await asyncio.sleep(self.DELAY_INC_MSG_PROCESSING_SLEEP)
    return amount_sent
async def _query_gossip(self):
try:
await util.wait_for2(self.initialized, LN_P2P_NETWORK_TIMEOUT)
except Exception as e:
raise GracefulDisconnect(f"Failed to initialize: {e!r}") from e
if self.lnworker == self.lnworker.network.lngossip:
if not self.their_features.supports(LnFeatures.GOSSIP_QUERIES_OPT):
raise GracefulDisconnect("remote does not support gossip_queries, which we need")
try:
ids, complete = await util.wait_for2(self.get_channel_range(), LN_P2P_NETWORK_TIMEOUT)
except asyncio.TimeoutError as e:
raise GracefulDisconnect("query_channel_range timed out") from e
self.logger.info('Received {} channel ids. (complete: {})'.format(len(ids), complete))
await self.lnworker.add_new_ids(ids)
self.request_gossip(int(time.time()))
while True:
todo = self.lnworker.get_ids_to_query()
if not todo:
@@ -524,6 +632,68 @@ class Peer(Logger, EventListener):
continue
await self.get_short_channel_ids(todo)
@staticmethod
def _is_valid_channel_range_query(payload: dict) -> bool:
    """Sanity-check a query_channel_range payload: correct chain, plausible
    start height and a non-zero block count."""
    checks = (
        payload.get('chain_hash') == constants.net.rev_genesis_bytes(),
        payload.get('first_blocknum', -1) >= constants.net.BLOCK_HEIGHT_FIRST_LIGHTNING_CHANNELS,
        payload.get('number_of_blocks', 0) >= 1,
    )
    return all(checks)
def _is_valid_short_channel_id_query(self, payload: dict) -> bool:
    """Sanity-check a query_short_channel_ids payload: correct chain, the
    uncompressed-encoding marker byte, and a whole number of 8-byte scids."""
    if payload.get('chain_hash') != constants.net.rev_genesis_bytes():
        return False
    encoded = payload['encoded_short_ids']
    encoding_byte = encoded[0]
    if encoding_byte != 0:
        self.logger.debug(f"got query_short_channel_ids with invalid encoding: {repr(encoding_byte)}")
        return False
    scid_bytes_len = len(encoded) - 1
    if scid_bytes_len % 8 != 0:
        self.logger.debug(f"got query_short_channel_ids with invalid length")
        return False
    return True
async def _send_reply_channel_range(self, payload: dict):
    """https://github.com/lightning/bolts/blob/acd383145dd8c3fecd69ce94e4a789767b984ac0/07-routing-gossip.md#requirements-5

    Serve a query_channel_range request by sending one or more
    reply_channel_range messages covering the requested block span.
    """
    first_blockheight: int = payload['first_blocknum']
    async with self.network.lngossip.gossip_request_semaphore:
        sorted_scids: List[ShortChannelID] = self.lnworker.channel_db.get_channels_in_range(
            first_blockheight,
            payload['number_of_blocks']
        )
        self.logger.debug(f"reply_channel_range to request "
                          f"first_height={first_blockheight}, "
                          f"num_blocks={payload['number_of_blocks']}, "
                          f"sending {len(sorted_scids)} scids")
        complete: bool = False
        while not complete:
            # create a 64800 byte chunk of skids, split the remaining scids
            encoded_scids, sorted_scids = b''.join(sorted_scids[:8100]), sorted_scids[8100:]
            complete = len(sorted_scids) == 0  # if there are no scids remaining we are done
            # number of blocks covered by the scids in this chunk
            if complete:
                # LAST MESSAGE MUST have first_blocknum plus number_of_blocks equal or greater than
                # the query_channel_range first_blocknum plus number_of_blocks.
                number_of_blocks = ((payload['first_blocknum'] + payload['number_of_blocks'])
                                    - first_blockheight)
            else:
                # we cover the range until the height of the first scid in the next chunk
                number_of_blocks = sorted_scids[0].block_height - first_blockheight
            self.send_message('reply_channel_range',
                              chain_hash=constants.net.rev_genesis_bytes(),
                              first_blocknum=first_blockheight,
                              number_of_blocks=number_of_blocks,
                              sync_complete=complete,
                              len=1+len(encoded_scids),
                              encoded_short_ids=b'\x00' + encoded_scids)
            if not complete:
                first_blockheight = sorted_scids[0].block_height
                # yield to the event loop between chunks
                await asyncio.sleep(self.DELAY_INC_MSG_PROCESSING_SLEEP)
    self.outgoing_gossip_reply = False  # allow the peer to send the next query
async def get_channel_range(self):
first_block = constants.net.BLOCK_HEIGHT_FIRST_LIGHTNING_CHANNELS
num_blocks = self.lnworker.network.get_local_height() - first_block
@@ -544,6 +714,7 @@ class Peer(Logger, EventListener):
# on_reply_channel_range. >>> first_block 497000, num_blocks 79038, num_ids 8000, complete False
# on_reply_channel_range. >>> first_block 497000, num_blocks 79038, num_ids 8000, complete False
# on_reply_channel_range. >>> first_block 497000, num_blocks 79038, num_ids 5344, complete True
# ADDENDUM (01/2025): now it's 'MUST set sync_complete to false if this is not the final reply_channel_range.'
while True:
index, num, complete, _ids = await self.reply_channel_range.get()
ids.update(_ids)
@@ -599,9 +770,34 @@ class Peer(Logger, EventListener):
complete = bool(int.from_bytes(payload['sync_complete'], 'big'))
encoded = payload['encoded_short_ids']
ids = self.decode_short_ids(encoded)
#self.logger.info(f"on_reply_channel_range. >>> first_block {first}, num_blocks {num}, num_ids {len(ids)}, complete {repr(payload['complete'])}")
# self.logger.info(f"on_reply_channel_range. >>> first_block {first}, num_blocks {num}, "
# f"num_ids {len(ids)}, complete {complete}")
self.reply_channel_range.put_nowait((first, num, complete, ids))
async def _send_reply_short_channel_ids(self, payload: dict):
    """Serve a query_short_channel_ids request: send all gossip we have for the
    requested scids, terminated by reply_short_channel_ids_end."""
    async with self.network.lngossip.gossip_request_semaphore:
        requested_scids = payload['encoded_short_ids']
        decoded_scids = [ShortChannelID.normalize(scid)
                         for scid in self.decode_short_ids(requested_scids)]
        self.logger.debug(f"serving query_short_channel_ids request: "
                          f"requested {len(decoded_scids)} scids")
        chan_db = self.lnworker.channel_db
        # a set, to deduplicate node_announcements shared between channels
        response: Set[bytes] = set()
        for scid in decoded_scids:
            requested_msgs = chan_db.get_gossip_for_scid_request(scid)
            response.update(requested_msgs)
        self.logger.debug(f"found {len(response)} gossip messages to serve scid request")
        for index, msg in enumerate(response):
            await self.transport.send_bytes_and_drain(msg)
            if index % 250 == 0:
                # this can be a lot of messages; yield to the event loop
                await asyncio.sleep(self.DELAY_INC_MSG_PROCESSING_SLEEP)
        self.send_message(
            'reply_short_channel_ids_end',
            chain_hash=constants.net.rev_genesis_bytes(),
            full_information=self.network.lngossip.is_synced()
        )
    self.outgoing_gossip_reply = False  # allow the peer to send the next query
async def get_short_channel_ids(self, ids):
self.logger.info(f'Querying {len(ids)} short_channel_ids')
assert not self.querying.is_set()
@@ -982,7 +1178,7 @@ class Peer(Logger, EventListener):
try:
chan.receive_new_commitment(remote_sig, [])
except LNProtocolWarning as e:
await self.send_warning(channel_id, message=str(e), close_connection=True)
self.send_warning(channel_id, message=str(e), close_connection=True)
chan.open_with_first_pcp(remote_per_commitment_point, remote_sig)
chan.set_state(ChannelState.OPENING)
if zeroconf:
@@ -1179,7 +1375,7 @@ class Peer(Logger, EventListener):
try:
chan.receive_new_commitment(remote_sig, [])
except LNProtocolWarning as e:
await self.send_warning(channel_id, message=str(e), close_connection=True)
self.send_warning(channel_id, message=str(e), close_connection=True)
sig_64, _ = chan.sign_next_commitment()
self.send_message('funding_signed',
channel_id=channel_id,
@@ -1509,8 +1705,7 @@ class Peer(Logger, EventListener):
timestamp = int(time.time())
node_id = privkey_to_pubkey(self.privkey)
features = self.features.for_node_announcement()
b = int.bit_length(features)
flen = b // 8 + int(bool(b % 8))
flen = features.min_len()
rgb_color = bytes.fromhex('000000')
alias = bytes(alias, 'utf8')
alias += bytes(32 - len(alias))
@@ -2444,7 +2639,7 @@ class Peer(Logger, EventListener):
# BOLT-02 check if they use the upfront shutdown script they advertised
if self.is_upfront_shutdown_script() and their_upfront_scriptpubkey:
if not (their_scriptpubkey == their_upfront_scriptpubkey):
await self.send_warning(
self.send_warning(
chan.channel_id,
"remote didn't use upfront shutdown script it committed to in channel opening",
close_connection=True)
@@ -2455,7 +2650,7 @@ class Peer(Logger, EventListener):
elif match_script_against_template(their_scriptpubkey, transaction.SCRIPTPUBKEY_TEMPLATE_WITNESS_V0):
pass
else:
await self.send_warning(
self.send_warning(
chan.channel_id,
f'scriptpubkey in received shutdown message does not conform to any template: {their_scriptpubkey.hex()}',
close_connection=True)

View File

@@ -199,6 +199,9 @@ class LNTransportBase:
privkey: bytes
peer_addr: Optional[LNPeerAddr] = None
def __init__(self):
self.drain_write_lock = asyncio.Lock()
def name(self) -> str:
pubkey = self.remote_pubkey()
pubkey_hex = pubkey.hex() if pubkey else pubkey
@@ -218,6 +221,12 @@ class LNTransportBase:
assert len(c) == len(msg) + 16
self.writer.write(lc+c)
async def send_bytes_and_drain(self, msg: bytes) -> None:
"""Should be used when possible (in async scope), to avoid memory exhaustion."""
async with self.drain_write_lock:
self.send_bytes(msg)
await self.writer.drain()
async def read_messages(self):
buffer = bytearray()
while True:

View File

@@ -9,6 +9,7 @@ from collections import namedtuple, defaultdict
from typing import NamedTuple, List, Tuple, Mapping, Optional, TYPE_CHECKING, Union, Dict, Set, Sequence
import re
import sys
import time
import electrum_ecc as ecc
from electrum_ecc import CURVE_ORDER, ecdsa_sig64_from_der_sig, ECPubkey, string_to_number
@@ -1509,6 +1510,10 @@ class LnFeatures(IntFlag):
features |= (1 << flag)
return features
def min_len(self) -> int:
    """Return the minimal number of bytes needed to hold this feature bit vector."""
    bits = int.bit_length(self)
    # ceiling division by 8 without floats
    return -(-bits // 8)
def supports(self, feature: 'LnFeatures') -> bool:
"""Returns whether given feature is enabled.
@@ -1634,6 +1639,56 @@ def get_ln_flag_pair_of_bit(flag_bit: int) -> int:
return flag_bit - 1
class GossipTimestampFilter:
    """A peer's gossip_timestamp_filter: the peer wants gossip whose timestamp
    lies in [first_timestamp, first_timestamp + timestamp_range)."""

    def __init__(self, first_timestamp: int, timestamp_range: int):
        self.first_timestamp = first_timestamp
        self.timestamp_range = timestamp_range
        # True once we sent them the requested gossip and only forward
        self.only_forwarding = first_timestamp >= int(time.time()) - 20

    def __str__(self):
        return (f"GossipTimestampFilter | first_timestamp={self.first_timestamp} | "
                f"timestamp_range={self.timestamp_range}")

    def in_range(self, timestamp: int) -> bool:
        """Whether *timestamp* falls within the half-open filter window."""
        begin = self.first_timestamp
        return begin <= timestamp < begin + self.timestamp_range

    @classmethod
    def from_payload(cls, payload) -> Optional['GossipTimestampFilter']:
        """Build a filter from a decoded gossip_timestamp_filter payload,
        or return None if required fields are missing or out of range."""
        if 'first_timestamp' not in payload or 'timestamp_range' not in payload:
            return None
        first_timestamp = payload['first_timestamp']
        if first_timestamp >= 0xFFFFFFFF:
            return None
        return cls(first_timestamp, payload['timestamp_range'])
class GossipForwardingMessage:
    """A raw gossip message queued for forwarding, with the metadata needed to
    filter it (timestamp, originating peer, scid)."""

    def __init__(self,
                 msg: bytes,
                 scid: Optional[ShortChannelID] = None,
                 timestamp: Optional[int] = None,
                 sender_node_id: Optional[bytes] = None):
        self.scid = scid  # type: Optional[ShortChannelID]
        self.sender_node_id = sender_node_id  # type: Optional[bytes]
        self.msg = msg
        self.timestamp = timestamp

    @classmethod
    def from_payload(cls, payload: dict) -> Optional['GossipForwardingMessage']:
        """Build a forwarding message from a decoded gossip payload, or return
        None if the payload lacks the raw message bytes."""
        if 'raw' not in payload:
            return None
        return cls(
            payload['raw'],
            ShortChannelID.normalize(payload.get('short_channel_id')),
            payload.get('timestamp'),
            payload.get('sender_node_id'),
        )
def list_enabled_ln_feature_bits(features: int) -> tuple[int, ...]:
"""Returns a list of enabled feature bits. If both opt and req are set, only
req will be included in the result."""

View File

@@ -60,7 +60,7 @@ from .lnutil import (
LnKeyFamily, LOCAL, REMOTE, MIN_FINAL_CLTV_DELTA_FOR_INVOICE, SENT, RECEIVED, HTLCOwner, UpdateAddHtlc, LnFeatures,
ShortChannelID, HtlcLog, NoPathFound, InvalidGossipMsg, FeeBudgetExceeded, ImportedChannelBackupStorage,
OnchainChannelBackupStorage, ln_compare_features, IncompatibleLightningFeatures, PaymentFeeBudget,
NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE
NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE, GossipForwardingMessage
)
from .lnonion import decode_onion_error, OnionFailureCode, OnionRoutingFailure, OnionPacket
from .lnmsg import decode_msg
@@ -162,7 +162,6 @@ LNWALLET_FEATURES = (
BASE_FEATURES
| LnFeatures.OPTION_DATA_LOSS_PROTECT_REQ
| LnFeatures.OPTION_STATIC_REMOTEKEY_REQ
| LnFeatures.GOSSIP_QUERIES_REQ
| LnFeatures.VAR_ONION_REQ
| LnFeatures.PAYMENT_SECRET_REQ
| LnFeatures.BASIC_MPP_OPT
@@ -175,8 +174,10 @@ LNWALLET_FEATURES = (
LNGOSSIP_FEATURES = (
BASE_FEATURES
| LnFeatures.GOSSIP_QUERIES_OPT
# LNGossip doesn't serve gossip but weirdly have to signal so
# that peers satisfy our queries
| LnFeatures.GOSSIP_QUERIES_REQ
| LnFeatures.GOSSIP_QUERIES_OPT
)
@@ -290,7 +291,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
peer_addr = LNPeerAddr(host, port, node_id)
self._trying_addr_now(peer_addr)
self.logger.info(f"adding peer {peer_addr}")
if node_id == self.node_keypair.pubkey:
if node_id == self.node_keypair.pubkey or self.is_our_lnwallet(node_id):
raise ErrorAddingPeer("cannot connect to self")
transport = LNTransport(self.node_keypair.privkey, peer_addr,
e_proxy=ESocksProxy.from_network_settings(self.network))
@@ -327,6 +328,14 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
def num_peers(self) -> int:
return sum([p.is_initialized() for p in self.peers.values()])
def is_our_lnwallet(self, node_id: bytes) -> bool:
"""Check if node_id is one of our own wallets"""
wallets = self.network.daemon.get_wallets()
for wallet in wallets.values():
if wallet.lnworker and wallet.lnworker.node_keypair.pubkey == node_id:
return True
return False
def start_network(self, network: 'Network'):
assert network
assert self.network is None, "already started"
@@ -511,6 +520,12 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
class LNGossip(LNWorker):
"""The LNGossip class is a separate, unannounced Lightning node with random id that is just querying
gossip from other nodes. The LNGossip node does not satisfy gossip queries, this is done by the
LNWallet class(es). LNWallets are the advertised nodes used for actual payments and only satisfy
peer queries without fetching gossip themselves. This separation is done so that gossip can be queried
independently of the active LNWallets. LNGossip keeps a curated batch of gossip in _forwarding_gossip
that is fetched by the LNWallets for regular forwarding."""
max_age = 14*24*3600
LOGGING_SHORTCUT = 'g'
@@ -521,12 +536,17 @@ class LNGossip(LNWorker):
node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY)
LNWorker.__init__(self, node_keypair, LNGOSSIP_FEATURES, config=config)
self.unknown_ids = set()
self._forwarding_gossip = [] # type: List[GossipForwardingMessage]
self._last_gossip_batch_ts = 0 # type: int
self._forwarding_gossip_lock = asyncio.Lock()
self.gossip_request_semaphore = asyncio.Semaphore(5)
def start_network(self, network: 'Network'):
super().start_network(network)
for coro in [
self._maintain_connectivity(),
self.maintain_db(),
self._maintain_forwarding_gossip()
]:
tg_coro = self.taskgroup.spawn(coro)
asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
@@ -539,6 +559,20 @@ class LNGossip(LNWorker):
self.channel_db.prune_orphaned_channels()
await asyncio.sleep(120)
async def _maintain_forwarding_gossip(self):
await self.channel_db.data_loaded.wait()
await self.wait_for_sync()
while True:
async with self._forwarding_gossip_lock:
self._forwarding_gossip = self.channel_db.get_forwarding_gossip_batch()
self._last_gossip_batch_ts = int(time.time())
self.logger.debug(f"{len(self._forwarding_gossip)} gossip messages available to forward")
await asyncio.sleep(60)
async def get_forwarding_gossip(self) -> tuple[List[GossipForwardingMessage], int]:
async with self._forwarding_gossip_lock:
return self._forwarding_gossip, self._last_gossip_batch_ts
async def add_new_ids(self, ids: Iterable[bytes]):
known = self.channel_db.get_channel_ids()
new = set(ids) - set(known)
@@ -563,12 +597,19 @@ class LNGossip(LNWorker):
return None, None, None
nchans_with_0p, nchans_with_1p, nchans_with_2p = self.channel_db.get_num_channels_partitioned_by_policy_count()
num_db_channels = nchans_with_0p + nchans_with_1p + nchans_with_2p
num_nodes = self.channel_db.num_nodes
num_nodes_associated_to_chans = max(len(self.channel_db._channels_for_node.keys()), 1)
# some channels will never have two policies (only one is in gossip?...)
# so if we have at least 1 policy for a channel, we consider that channel "complete" here
current_est = num_db_channels - nchans_with_0p
total_est = len(self.unknown_ids) + num_db_channels
progress = current_est / total_est if total_est and current_est else 0
progress_chans = current_est / total_est if total_est and current_est else 0
# consider that we got at least 10% of the node anns of node ids we know about
progress_nodes = min((num_nodes / num_nodes_associated_to_chans) * 10, 1)
progress = (progress_chans * 3 + progress_nodes) / 4 # weigh the channel progress higher
# self.logger.debug(f"Sync process chans: {progress_chans} | Progress nodes: {progress_nodes} | "
# f"Total progress: {progress} | NUM_NODES: {num_nodes} / {num_nodes_associated_to_chans}")
progress_percent = (1.0 / 0.95 * progress) * 100
progress_percent = min(progress_percent, 100)
progress_percent = round(progress_percent)
@@ -582,8 +623,8 @@ class LNGossip(LNWorker):
# note: we run in the originating peer's TaskGroup, so we can safely raise here
# and disconnect only from that peer
await self.channel_db.data_loaded.wait()
self.logger.debug(f'process_gossip {len(chan_anns)} {len(node_anns)} {len(chan_upds)}')
self.logger.debug(f'process_gossip ca: {len(chan_anns)} na: {len(node_anns)} '
f'cu: {len(chan_upds)}')
# channel announcements
def process_chan_anns():
for payload in chan_anns:
@@ -610,6 +651,24 @@ class LNGossip(LNWorker):
if categorized_chan_upds.good:
self.logger.debug(f'process_gossip: {len(categorized_chan_upds.good)}/{len(chan_upds)}')
def is_synced(self) -> bool:
_, _, percentage_synced = self.get_sync_progress_estimate()
if percentage_synced is not None and percentage_synced >= 100:
return True
return False
async def wait_for_sync(self, times_to_check: int = 3):
"""Check if we have 100% sync progress `times_to_check` times in a row (because the
estimate often jumps back after some seconds when doing initial sync)."""
while True:
if self.is_synced():
times_to_check -= 1
if times_to_check <= 0:
return
await asyncio.sleep(10)
# flush the gossip queue so we don't forward old gossip after sync is complete
self.channel_db.get_forwarding_gossip_batch()
class PaySession(Logger):
def __init__(
@@ -771,6 +830,8 @@ class LNWallet(LNWorker):
features |= LnFeatures.OPTION_ANCHORS_ZERO_FEE_HTLC_OPT
if self.config.ACCEPT_ZEROCONF_CHANNELS:
features |= LnFeatures.OPTION_ZEROCONF_OPT
if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS and self.config.LIGHTNING_USE_GOSSIP:
features |= LnFeatures.GOSSIP_QUERIES_OPT # signal we have gossip to fetch
LNWorker.__init__(self, self.node_keypair, features, config=self.config)
self.lnwatcher = None
self.lnrater: LNRater = None

View File

@@ -1247,7 +1247,7 @@ class TestPeerDirect(TestPeer):
async def action():
await util.wait_for2(p1.initialized, 1)
await util.wait_for2(p2.initialized, 1)
await p1.send_warning(alice_channel.channel_id, 'be warned!', close_connection=True)
p1.send_warning(alice_channel.channel_id, 'be warned!', close_connection=True)
gath = asyncio.gather(action(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
with self.assertRaises(GracefulDisconnect):
await gath
@@ -1259,7 +1259,7 @@ class TestPeerDirect(TestPeer):
async def action():
await util.wait_for2(p1.initialized, 1)
await util.wait_for2(p2.initialized, 1)
await p1.send_error(alice_channel.channel_id, 'some error happened!', force_close_channel=True)
p1.send_error(alice_channel.channel_id, 'some error happened!', force_close_channel=True)
assert alice_channel.is_closed()
gath.cancel()
gath = asyncio.gather(action(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())