diff --git a/electrum/channel_db.py b/electrum/channel_db.py
index cc81be62c..42ca2626a 100644
--- a/electrum/channel_db.py
+++ b/electrum/channel_db.py
@@ -43,7 +43,8 @@ from .sql_db import SqlDB, sql
 from . import constants, util
 from .util import profiler, get_headers_dir, is_ip_address, json_normalize, UserFacingException, is_private_netaddress
 from .lntransport import LNPeerAddr
-from .lnutil import ShortChannelID, validate_features, IncompatibleOrInsaneFeatures, InvalidGossipMsg
+from .lnutil import (ShortChannelID, validate_features, IncompatibleOrInsaneFeatures,
+                     InvalidGossipMsg, GossipForwardingMessage, GossipTimestampFilter)
 from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
 from .lnmsg import decode_msg
 from .crypto import sha256d
@@ -68,6 +69,7 @@ class ChannelInfo(NamedTuple):
     node1_id: bytes
     node2_id: bytes
     capacity_sat: Optional[int]
+    raw: Optional[bytes] = None

     @staticmethod
     def from_msg(payload: dict) -> 'ChannelInfo':
@@ -82,12 +84,14 @@ class ChannelInfo(NamedTuple):
             short_channel_id = ShortChannelID.normalize(channel_id),
             node1_id = node_id_1,
             node2_id = node_id_2,
-            capacity_sat = capacity_sat
+            capacity_sat = capacity_sat,
+            raw = payload.get('raw')
         )

     @staticmethod
     def from_raw_msg(raw: bytes) -> 'ChannelInfo':
         payload_dict = decode_msg(raw)[1]
+        payload_dict['raw'] = raw
         return ChannelInfo.from_msg(payload_dict)

     @staticmethod
@@ -111,6 +115,7 @@ class Policy(NamedTuple):
     channel_flags: int
     message_flags: int
     timestamp: int
+    raw: Optional[bytes] = None

     @staticmethod
     def from_msg(payload: dict) -> 'Policy':
@@ -124,12 +129,14 @@ class Policy(NamedTuple):
             message_flags = int.from_bytes(payload['message_flags'], "big"),
             channel_flags = int.from_bytes(payload['channel_flags'], "big"),
             timestamp = payload['timestamp'],
+            raw = payload.get('raw'),
         )

     @staticmethod
-    def from_raw_msg(key:bytes, raw: bytes) -> 'Policy':
+    def from_raw_msg(key: bytes, raw: bytes) -> 'Policy':
         payload = decode_msg(raw)[1]
         payload['start_node'] = key[8:]
+        payload['raw'] = raw
         return Policy.from_msg(payload)

     @staticmethod
@@ -163,6 +170,7 @@ class NodeInfo(NamedTuple):
     features: int
     timestamp: int
     alias: str
+    raw: Optional[bytes]

     @staticmethod
     def from_msg(payload) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]:
@@ -182,12 +190,18 @@ class NodeInfo(NamedTuple):
         except Exception:
             alias = ''
         timestamp = payload['timestamp']
-        node_info = NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias)
+        node_info = NodeInfo(
+            node_id=node_id,
+            features=features,
+            timestamp=timestamp,
+            alias=alias,
+            raw=payload.get('raw'))
         return node_info, peer_addrs

     @staticmethod
     def from_raw_msg(raw: bytes) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]:
         payload_dict = decode_msg(raw)[1]
+        payload_dict['raw'] = raw
         return NodeInfo.from_msg(payload_dict)

     @staticmethod
@@ -240,6 +254,7 @@ class NodeInfo(NamedTuple):
             nonlocal buf
             data, buf = buf[0:n], buf[n:]
             return data
+
         addresses = []
         while buf:
             atype = ord(read(1))
@@ -389,6 +404,12 @@ class ChannelDB(SqlDB):
         self._chans_with_1_policies = set()  # type: Set[ShortChannelID]
         self._chans_with_2_policies = set()  # type: Set[ShortChannelID]
+        self.forwarding_lock = threading.RLock()
+        self.fwd_channels = []  # type: List[GossipForwardingMessage]
+        self.fwd_orphan_channels = []  # type: List[GossipForwardingMessage]
+        self.fwd_channel_updates = []  # type: List[GossipForwardingMessage]
+        self.fwd_node_announcements = []  # type: List[GossipForwardingMessage]
+
         self.data_loaded = asyncio.Event()
         self.network = network  # only for callback
@@ -487,6 +508,9 @@ class ChannelDB(SqlDB):
         self._update_num_policies_for_chan(channel_info.short_channel_id)
         if 'raw' in msg:
             self._db_save_channel(channel_info.short_channel_id, msg['raw'])
+        with self.forwarding_lock:
+            if fwd_msg := GossipForwardingMessage.from_payload(msg):
+                self.fwd_channels.append(fwd_msg)

     def policy_changed(self, old_policy: Policy, new_policy: Policy, verbose: bool) -> bool:
         changed = False
@@ -555,6 +579,10 @@ class ChannelDB(SqlDB):
             if old_policy and not self.policy_changed(old_policy, policy, verbose):
                 return UpdateStatus.UNCHANGED
         else:
+            if policy.message_flags & 0b10 == 0:  # check that `dont_forward` is not set
+                with self.forwarding_lock:
+                    if fwd_msg := GossipForwardingMessage.from_payload(payload):
+                        self.fwd_channel_updates.append(fwd_msg)
             return UpdateStatus.GOOD

     def add_channel_updates(self, payloads, max_age=None) -> CategorizedChannelUpdates:
@@ -667,7 +695,7 @@ class ChannelDB(SqlDB):
         # note: signatures have already been verified.
         if type(msg_payloads) is dict:
             msg_payloads = [msg_payloads]
-        new_nodes = {}
+        new_nodes = set()  # type: Set[bytes]
         for msg_payload in msg_payloads:
             try:
                 node_info, node_addresses = NodeInfo.from_msg(msg_payload)
@@ -681,9 +709,7 @@ class ChannelDB(SqlDB):
             node = self._nodes.get(node_id)
             if node and node.timestamp >= node_info.timestamp:
                 continue
-            node = new_nodes.get(node_id)
-            if node and node.timestamp >= node_info.timestamp:
-                continue
+            new_nodes.add(node_id)
             # save
             with self.lock:
                 self._nodes[node_id] = node_info
@@ -694,6 +720,9 @@ class ChannelDB(SqlDB):
                     net_addr = NetAddress(addr.host, addr.port)
                     self._addresses[node_id][net_addr] = self._addresses[node_id].get(net_addr) or 0
             self._db_save_node_addresses(node_addresses)
+            with self.forwarding_lock:
+                if fwd_msg := GossipForwardingMessage.from_payload(msg_payload):
+                    self.fwd_node_announcements.append(fwd_msg)
         self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
         self.update_counts()
@@ -995,6 +1024,127 @@ class ChannelDB(SqlDB):
                 return k
         raise Exception('node not found')

+    def clear_forwarding_gossip(self) -> None:
+        with self.forwarding_lock:
+            self.fwd_channels.clear()
+            self.fwd_channel_updates.clear()
+            self.fwd_node_announcements.clear()
+
+    def filter_orphan_channel_anns(
+        self, channel_anns: List[GossipForwardingMessage]
+    ) -> Tuple[List, List]:
+        """Check if the channel announcements we want to forward have at least one update."""
+        to_forward_anns = []
+        orphaned_channel_anns = []
+        for channel in channel_anns:
+            if channel.scid is None:
+                continue
+            elif (channel.scid in self._chans_with_1_policies
+                    or channel.scid in self._chans_with_2_policies):
+                to_forward_anns.append(channel)
+                continue
+            orphaned_channel_anns.append(channel)
+        return to_forward_anns, orphaned_channel_anns
+
+    def set_fwd_channel_anns_ts(self, channel_anns: List[GossipForwardingMessage]) \
+            -> List[GossipForwardingMessage]:
+        """Set the timestamps of the passed channel announcements from the corresponding policies."""
+        timestamped_chan_anns: List[GossipForwardingMessage] = []
+        with self.lock:
+            policies = self._policies.copy()
+            channels = self._channels.copy()
+
+        for chan_ann in channel_anns:
+            if chan_ann.timestamp is not None:
+                timestamped_chan_anns.append(chan_ann)
+                continue
+
+            scid = chan_ann.scid
+            if (channel_info := channels.get(scid)) is None:
+                continue
+
+            policy1 = policies.get((channel_info.node1_id, scid))
+            policy2 = policies.get((channel_info.node2_id, scid))
+            potential_timestamps = []
+            for policy in [policy1, policy2]:
+                if policy is not None:
+                    potential_timestamps.append(policy.timestamp)
+            if not potential_timestamps:
+                continue
+            chan_ann.timestamp = min(potential_timestamps)
+            timestamped_chan_anns.append(chan_ann)
+        return timestamped_chan_anns
+
+    def get_forwarding_gossip_batch(self) -> List[GossipForwardingMessage]:
+        with self.forwarding_lock:
+            fwd_gossip = self.fwd_channel_updates + self.fwd_node_announcements
+            channel_anns = self.fwd_channels.copy()
+            self.clear_forwarding_gossip()
+
+        fwd_chan_anns1, _ = self.filter_orphan_channel_anns(self.fwd_orphan_channels)
+        fwd_chan_anns2, self.fwd_orphan_channels = self.filter_orphan_channel_anns(channel_anns)
+        channel_anns = self.set_fwd_channel_anns_ts(fwd_chan_anns1 + fwd_chan_anns2)
+        return channel_anns + fwd_gossip
+
+    def get_gossip_in_timespan(self, timespan: GossipTimestampFilter) \
+            -> List[GossipForwardingMessage]:
+        """Return a list of gossip messages matching the requested timespan."""
+        forwarding_gossip = []
+        with self.lock:
+            chans = self._channels.copy()
+            policies = self._policies.copy()
+            nodes = self._nodes.copy()
+
+        for short_id, chan in chans.items():
+            # fetching the timestamp from the channel update (according to BOLT-07)
+            chan_up_n1 = policies.get((chan.node1_id, short_id))
+            chan_up_n2 = policies.get((chan.node2_id, short_id))
+            updates = []
+            for policy in [chan_up_n1, chan_up_n2]:
+                if policy and policy.raw and timespan.in_range(policy.timestamp):
+                    if policy.message_flags & 0b10 == 0:  # check that `dont_forward` is not set
+                        updates.append(GossipForwardingMessage(
+                            msg=policy.raw,
+                            timestamp=policy.timestamp))
+            if not updates or chan.raw is None:
+                continue
+            chan_ann_ts = min(update.timestamp for update in updates)
+            channel_announcement = GossipForwardingMessage(msg=chan.raw, timestamp=chan_ann_ts)
+            forwarding_gossip.extend([channel_announcement] + updates)

+        for node_ann in nodes.values():
+            if timespan.in_range(node_ann.timestamp) and node_ann.raw:
+                forwarding_gossip.append(GossipForwardingMessage(
+                    msg=node_ann.raw,
+                    timestamp=node_ann.timestamp))
+        return forwarding_gossip
+
+    def get_channels_in_range(self, first_blocknum: int, number_of_blocks: int) -> List[ShortChannelID]:
+        with self.lock:
+            channels = self._channels.copy()
+        scids: List[ShortChannelID] = []
+        for scid in channels:
+            if first_blocknum <= scid.block_height < first_blocknum + number_of_blocks:
+                scids.append(scid)
+        scids.sort()
+        return scids
+
+    def get_gossip_for_scid_request(self, scid: ShortChannelID) -> List[bytes]:
+        requested_gossip = []
+
+        chan_ann = self._channels.get(scid)
+        if not chan_ann or not chan_ann.raw:
+            return []
+        chan_up1 = self._policies.get((chan_ann.node1_id, scid))
+        chan_up2 = self._policies.get((chan_ann.node2_id, scid))
+        node_ann1 = self._nodes.get(chan_ann.node1_id)
+        node_ann2 = self._nodes.get(chan_ann.node2_id)
+
+        for msg in [chan_ann, chan_up1, chan_up2, node_ann1, node_ann2]:
+            if msg and msg.raw:
+                requested_gossip.append(msg.raw)
+        return requested_gossip
+
     def to_dict(self) -> dict:
         """ Generates a graph representation in terms of a dictionary.
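Illustrative sketch (not part of the patch) of the half-open block window used by the new ChannelDB.get_channels_in_range() above. ShortChannelID.from_components(block_height, tx_pos_in_block, output_index) is assumed to behave as in electrum.lnutil:

from electrum.lnutil import ShortChannelID

# a query_channel_range for blocks [500_000, 500_100):
first_blocknum, number_of_blocks = 500_000, 100
scids = [
    ShortChannelID.from_components(499_999, 1, 0),  # excluded: below the window
    ShortChannelID.from_components(500_000, 5, 1),  # included: lower bound is inclusive
    ShortChannelID.from_components(500_099, 2, 0),  # included: last block of the window
    ShortChannelID.from_components(500_100, 0, 0),  # excluded: upper bound is exclusive
]
in_range = sorted(s for s in scids if first_blocknum <= s.block_height < first_blocknum + number_of_blocks)
assert [s.block_height for s in in_range] == [500_000, 500_099]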
diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
index afe96b5cb..0810d2812 100644
--- a/electrum/lnpeer.py
+++ b/electrum/lnpeer.py
@@ -9,7 +9,7 @@ from collections import OrderedDict, defaultdict
 import asyncio
 import os
 import time
-from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set, Callable, Awaitable
+from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set, Callable, Awaitable, List
 from datetime import datetime
 import functools

@@ -45,7 +45,8 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, ChannelConf
                      RemoteMisbehaving, ShortChannelID, IncompatibleLightningFeatures,
                      derive_payment_secret_from_payment_preimage, ChannelType, LNProtocolWarning, validate_features,
-                     IncompatibleOrInsaneFeatures, FeeBudgetExceeded)
+                     IncompatibleOrInsaneFeatures, FeeBudgetExceeded,
+                     GossipForwardingMessage, GossipTimestampFilter)
 from .lnutil import FeeUpdate, channel_id_from_funding_tx, PaymentFeeBudget
 from .lnutil import serialize_htlc_key, Keypair
 from .lntransport import LNTransport, LNTransportBase, LightningPeerConnectionClosed, HandshakeFailed
@@ -74,7 +75,9 @@ class Peer(Logger, EventListener):
     ORDERED_MESSAGES = (
         'accept_channel', 'funding_signed', 'funding_created', 'accept_channel', 'closing_signed')
     SPAMMY_MESSAGES = (
-        'ping', 'pong', 'channel_announcement', 'node_announcement', 'channel_update',)
+        'ping', 'pong', 'channel_announcement', 'node_announcement', 'channel_update',
+        'gossip_timestamp_filter', 'reply_channel_range', 'query_channel_range',
+        'query_short_channel_ids', 'reply_short_channel_ids', 'reply_short_channel_ids_end')

     DELAY_INC_MSG_PROCESSING_SLEEP = 0.01
@@ -106,6 +109,8 @@ class Peer(Logger, EventListener):
         self.reply_channel_range = asyncio.Queue()
         # gossip uses a single queue to preserve message order
         self.gossip_queue = asyncio.Queue()
+        self.gossip_timestamp_filter = None  # type: Optional[GossipTimestampFilter]
+        self.outgoing_gossip_reply = False  # type: bool
         self.ordered_message_queues = defaultdict(asyncio.Queue)  # type: Dict[bytes, asyncio.Queue]  # for messages that are ordered
         self.temp_id_to_id = {}  # type: Dict[bytes, Optional[bytes]]  # to forward error messages
         self.funding_created_sent = set()  # for channels in PREOPENING
@@ -168,8 +173,7 @@ class Peer(Logger, EventListener):
         await self.transport.handshake()
         self.logger.info(f"handshake done for {self.transport.peer_addr or self.pubkey.hex()}")
         features = self.features.for_init_message()
-        b = int.bit_length(features)
-        flen = b // 8 + int(bool(b % 8))
+        flen = features.min_len()
         self.send_message(
             "init", gflen=0, flen=flen,
             features=features,
@@ -240,6 +244,7 @@ class Peer(Logger, EventListener):
         # raw message is needed to check signature
         if message_type in ['node_announcement', 'channel_announcement', 'channel_update']:
             payload['raw'] = message
+            payload['sender_node_id'] = self.pubkey
         # note: the message handler might be async or non-async. In either case, by default,
         # we wait for it to complete before we return, i.e. before the next message is processed.
         if asyncio.iscoroutinefunction(f):
@@ -294,7 +299,7 @@ class Peer(Logger, EventListener):
                 return
             raise GracefulDisconnect

-    async def send_warning(self, channel_id: bytes, message: str = None, *, close_connection=False):
+    def send_warning(self, channel_id: bytes, message: str = None, *, close_connection=False):
         """Sends a warning and disconnects if close_connection.

         Note:
@@ -313,7 +318,7 @@ class Peer(Logger, EventListener):
         if close_connection:
             raise GracefulDisconnect

-    async def send_error(self, channel_id: bytes, message: str = None, *, force_close_channel=False):
+    def send_error(self, channel_id: bytes, message: str = None, *, force_close_channel=False):
         """Sends an error message and force closes the channel.

         Note:
@@ -406,6 +411,42 @@ class Peer(Logger, EventListener):
         if not self.lnworker.uses_trampoline():
             self.gossip_queue.put_nowait(('channel_update', payload))

+    def on_query_channel_range(self, payload):
+        if self.lnworker == self.lnworker.network.lngossip or not self._should_forward_gossip():
+            return
+        if not self._is_valid_channel_range_query(payload):
+            return self.send_warning(bytes(32), "received invalid query_channel_range")
+        if self.outgoing_gossip_reply:
+            return self.send_warning(bytes(32), "received multiple queries at the same time")
+        self.outgoing_gossip_reply = True
+        self.gossip_queue.put_nowait(('query_channel_range', payload))
+
+    def on_query_short_channel_ids(self, payload):
+        if self.lnworker == self.lnworker.network.lngossip or not self._should_forward_gossip():
+            return
+        if self.outgoing_gossip_reply:
+            return self.send_warning(bytes(32), "received multiple queries at the same time")
+        if not self._is_valid_short_channel_id_query(payload):
+            return self.send_warning(bytes(32), "invalid query_short_channel_ids")
+        self.outgoing_gossip_reply = True
+        self.gossip_queue.put_nowait(('query_short_channel_ids', payload))
+
+    def on_gossip_timestamp_filter(self, payload):
+        if self._should_forward_gossip():
+            self.set_gossip_timestamp_filter(payload)
+
+    def set_gossip_timestamp_filter(self, payload: dict) -> None:
+        """Set the gossip_timestamp_filter for this peer. If the peer requested historical gossip,
+        the request is put on the queue; otherwise only the forwarding loop will check the filter."""
+        if payload.get('chain_hash') != constants.net.rev_genesis_bytes():
+            return
+        filter = GossipTimestampFilter.from_payload(payload)
+        self.gossip_timestamp_filter = filter
+        self.logger.debug(f"got gossip_ts_filter from peer {self.pubkey.hex()}: "
+                          f"{str(self.gossip_timestamp_filter)}")
+        if filter and not filter.only_forwarding:
+            self.gossip_queue.put_nowait(('gossip_timestamp_filter', None))
+
     def maybe_save_remote_update(self, payload):
         if not self.channels:
             return
@@ -424,6 +465,8 @@ class Peer(Logger, EventListener):
         # This code assumes gossip_queries is set. BOLT7: "if the
         # gossip_queries feature is negotiated, [a node] MUST NOT
         # send gossip it did not generate itself"
+        # NOTE: The definition of gossip_queries changed:
+        # https://github.com/lightning/bolts/commit/fce8bab931674a81a9ea895c9e9162e559e48a65
         short_channel_id = ShortChannelID(payload['short_channel_id'])
         self.logger.info(f'received orphan channel update {short_channel_id}')
         self.orphan_channel_updates[short_channel_id] = payload
@@ -462,13 +505,14 @@ class Peer(Logger, EventListener):
     @handle_disconnect
     async def main_loop(self):
         async with self.taskgroup as group:
-            await group.spawn(self._message_loop())
             await group.spawn(self.htlc_switch())
-            await group.spawn(self.query_gossip())
-            await group.spawn(self.process_gossip())
-            await group.spawn(self.send_own_gossip())
+            await group.spawn(self._message_loop())
+            await group.spawn(self._query_gossip())
+            await group.spawn(self._process_gossip())
+            await group.spawn(self._send_own_gossip())
+            await group.spawn(self._forward_gossip())

-    async def process_gossip(self):
+    async def _process_gossip(self):
         while True:
             await asyncio.sleep(5)
             if not self.network.lngossip:
@@ -484,6 +528,12 @@ class Peer(Logger, EventListener):
                     chan_upds.append(payload)
                 elif name == 'node_announcement':
                     node_anns.append(payload)
+                elif name == 'query_channel_range':
+                    await self.taskgroup.spawn(self._send_reply_channel_range(payload))
+                elif name == 'query_short_channel_ids':
+                    await self.taskgroup.spawn(self._send_reply_short_channel_ids(payload))
+                elif name == 'gossip_timestamp_filter':
+                    await self.taskgroup.spawn(self._handle_historical_gossip_request())
                 else:
                     raise Exception('unknown message')
                 if self.gossip_queue.empty():
@@ -491,7 +541,7 @@ class Peer(Logger, EventListener):
             if self.network.lngossip:
                 await self.network.lngossip.process_gossip(chan_anns, node_anns, chan_upds)

-    async def send_own_gossip(self):
+    async def _send_own_gossip(self):
         if self.lnworker == self.lnworker.network.lngossip:
             return
         await asyncio.sleep(10)
@@ -505,18 +555,76 @@ class Peer(Logger, EventListener):
                     self.maybe_send_channel_announcement(chan)
             await asyncio.sleep(600)

-    async def query_gossip(self):
+    def _should_forward_gossip(self) -> bool:
+        if (self.network.lngossip != self.lnworker
+                and not self.lnworker.uses_trampoline()
+                and self.features.supports(LnFeatures.GOSSIP_QUERIES_REQ)):
+            return True
+        return False
+
+    async def _forward_gossip(self):
+        if not self._should_forward_gossip():
+            return
+
+        async def send_new_gossip_with_semaphore(gossip: List[GossipForwardingMessage]):
+            async with self.network.lngossip.gossip_request_semaphore:
+                sent = await self._send_gossip_messages(gossip)
+                if sent > 0:
+                    self.logger.debug(f"forwarded {sent} gossip messages to {self.pubkey.hex()}")
+
+        lngossip = self.network.lngossip
+        last_gossip_batch_ts = 0
+        while True:
+            await asyncio.sleep(10)
+            if not self.gossip_timestamp_filter:
+                continue  # peer didn't request gossip
+
+            new_gossip, last_lngossip_refresh_ts = await lngossip.get_forwarding_gossip()
+            if not last_lngossip_refresh_ts > last_gossip_batch_ts:
+                continue  # no new batch available
+            last_gossip_batch_ts = last_lngossip_refresh_ts
+
+            await self.taskgroup.spawn(send_new_gossip_with_semaphore(new_gossip))
+
+    async def _handle_historical_gossip_request(self):
+        """Called when a peer requests historical gossip with a gossip_timestamp_filter query."""
+        filter = self.gossip_timestamp_filter
+        if not self._should_forward_gossip() or not filter or filter.only_forwarding:
+            return
+        async with self.network.lngossip.gossip_request_semaphore:
+            requested_gossip = self.lnworker.channel_db.get_gossip_in_timespan(filter)
+            filter.only_forwarding = True
+            sent = await self._send_gossip_messages(requested_gossip)
+            if sent > 0:
+                self.logger.debug(f"forwarded {sent} historical gossip messages to {self.pubkey.hex()}")
+
+    async def _send_gossip_messages(self, messages: List[GossipForwardingMessage]) -> int:
+        amount_sent = 0
+        for msg in messages:
+            if self.gossip_timestamp_filter.in_range(msg.timestamp) \
+                    and self.pubkey != msg.sender_node_id:
+                await self.transport.send_bytes_and_drain(msg.msg)
+                amount_sent += 1
+                if amount_sent % 250 == 0:
+                    # this can be a lot of messages, completely blocking the event loop
+                    await asyncio.sleep(self.DELAY_INC_MSG_PROCESSING_SLEEP)
+        return amount_sent
+
+    async def _query_gossip(self):
         try:
             await util.wait_for2(self.initialized, LN_P2P_NETWORK_TIMEOUT)
         except Exception as e:
             raise GracefulDisconnect(f"Failed to initialize: {e!r}") from e
         if self.lnworker == self.lnworker.network.lngossip:
+            if not self.their_features.supports(LnFeatures.GOSSIP_QUERIES_OPT):
+                raise GracefulDisconnect("remote does not support gossip_queries, which we need")
             try:
                 ids, complete = await util.wait_for2(self.get_channel_range(), LN_P2P_NETWORK_TIMEOUT)
             except asyncio.TimeoutError as e:
                 raise GracefulDisconnect("query_channel_range timed out") from e
             self.logger.info('Received {} channel ids. (complete: {})'.format(len(ids), complete))
             await self.lnworker.add_new_ids(ids)
+            self.request_gossip(int(time.time()))
             while True:
                 todo = self.lnworker.get_ids_to_query()
                 if not todo:
@@ -524,6 +632,68 @@ class Peer(Logger, EventListener):
                     continue
                 await self.get_short_channel_ids(todo)

+    @staticmethod
+    def _is_valid_channel_range_query(payload: dict) -> bool:
+        if payload.get('chain_hash') != constants.net.rev_genesis_bytes():
+            return False
+        if payload.get('first_blocknum', -1) < constants.net.BLOCK_HEIGHT_FIRST_LIGHTNING_CHANNELS:
+            return False
+        if payload.get('number_of_blocks', 0) < 1:
+            return False
+        return True
+
+    def _is_valid_short_channel_id_query(self, payload: dict) -> bool:
+        if payload.get('chain_hash') != constants.net.rev_genesis_bytes():
+            return False
+        enc_short_ids = payload['encoded_short_ids']
+        if enc_short_ids[0] != 0:
+            self.logger.debug(f"got query_short_channel_ids with invalid encoding: {repr(enc_short_ids[0])}")
+            return False
+        if (len(enc_short_ids) - 1) % 8 != 0:
+            self.logger.debug("got query_short_channel_ids with invalid length")
+            return False
+        return True
+
+    async def _send_reply_channel_range(self, payload: dict):
+        """https://github.com/lightning/bolts/blob/acd383145dd8c3fecd69ce94e4a789767b984ac0/07-routing-gossip.md#requirements-5"""
+        first_blockheight: int = payload['first_blocknum']
+
+        async with self.network.lngossip.gossip_request_semaphore:
+            sorted_scids: List[ShortChannelID] = self.lnworker.channel_db.get_channels_in_range(
+                first_blockheight,
+                payload['number_of_blocks']
+            )
+            self.logger.debug(f"reply_channel_range to request "
+                              f"first_height={first_blockheight}, "
+                              f"num_blocks={payload['number_of_blocks']}, "
+                              f"sending {len(sorted_scids)} scids")
+
+            complete: bool = False
+            while not complete:
+                # create a 64800-byte chunk of scids, split off the remaining scids
+                encoded_scids, sorted_scids = b''.join(sorted_scids[:8100]), sorted_scids[8100:]
+                complete = len(sorted_scids) == 0  # if there are no scids remaining we are done
+                # number of blocks covered by the scids in this chunk
+                if complete:
+                    # the LAST message MUST have first_blocknum plus number_of_blocks equal to or
+                    # greater than the query_channel_range first_blocknum plus number_of_blocks
                    number_of_blocks = ((payload['first_blocknum'] + payload['number_of_blocks'])
                                        - first_blockheight)
+                else:
+                    # we cover the range up to the height of the first scid in the next chunk
+                    number_of_blocks = sorted_scids[0].block_height - first_blockheight
+                self.send_message('reply_channel_range',
+                                  chain_hash=constants.net.rev_genesis_bytes(),
+                                  first_blocknum=first_blockheight,
+                                  number_of_blocks=number_of_blocks,
+                                  sync_complete=complete,
+                                  len=1+len(encoded_scids),
+                                  encoded_short_ids=b'\x00' + encoded_scids)
+                if not complete:
+                    first_blockheight = sorted_scids[0].block_height
+                await asyncio.sleep(self.DELAY_INC_MSG_PROCESSING_SLEEP)
+        self.outgoing_gossip_reply = False
+
     async def get_channel_range(self):
         first_block = constants.net.BLOCK_HEIGHT_FIRST_LIGHTNING_CHANNELS
         num_blocks = self.lnworker.network.get_local_height() - first_block
@@ -544,6 +714,7 @@ class Peer(Logger, EventListener):
         # on_reply_channel_range. >>> first_block 497000, num_blocks 79038, num_ids 8000, complete False
         # on_reply_channel_range. >>> first_block 497000, num_blocks 79038, num_ids 8000, complete False
         # on_reply_channel_range. >>> first_block 497000, num_blocks 79038, num_ids 5344, complete True
+        # ADDENDUM (01/2025): now it's 'MUST set sync_complete to false if this is not the final reply_channel_range.'
         while True:
             index, num, complete, _ids = await self.reply_channel_range.get()
             ids.update(_ids)
@@ -599,9 +770,34 @@ class Peer(Logger, EventListener):
         complete = bool(int.from_bytes(payload['sync_complete'], 'big'))
         encoded = payload['encoded_short_ids']
         ids = self.decode_short_ids(encoded)
-        #self.logger.info(f"on_reply_channel_range. >>> first_block {first}, num_blocks {num}, num_ids {len(ids)}, complete {repr(payload['complete'])}")
+        # self.logger.info(f"on_reply_channel_range. >>> first_block {first}, num_blocks {num}, "
+        #                  f"num_ids {len(ids)}, complete {complete}")
         self.reply_channel_range.put_nowait((first, num, complete, ids))

+    async def _send_reply_short_channel_ids(self, payload: dict):
+        async with self.network.lngossip.gossip_request_semaphore:
+            requested_scids = payload['encoded_short_ids']
+            decoded_scids = [ShortChannelID.normalize(scid)
+                             for scid in self.decode_short_ids(requested_scids)]
+            self.logger.debug(f"serving query_short_channel_ids request: "
+                              f"requested {len(decoded_scids)} scids")
+            chan_db = self.lnworker.channel_db
+            response: Set[bytes] = set()
+            for scid in decoded_scids:
+                requested_msgs = chan_db.get_gossip_for_scid_request(scid)
+                response.update(requested_msgs)
+            self.logger.debug(f"found {len(response)} gossip messages to serve scid request")
+            for index, msg in enumerate(response):
+                await self.transport.send_bytes_and_drain(msg)
+                if index % 250 == 0:
+                    await asyncio.sleep(self.DELAY_INC_MSG_PROCESSING_SLEEP)
+            self.send_message(
+                'reply_short_channel_ids_end',
+                chain_hash=constants.net.rev_genesis_bytes(),
+                full_information=self.network.lngossip.is_synced()
+            )
+        self.outgoing_gossip_reply = False
+
     async def get_short_channel_ids(self, ids):
         self.logger.info(f'Querying {len(ids)} short_channel_ids')
         assert not self.querying.is_set()
@@ -982,7 +1178,7 @@ class Peer(Logger, EventListener):
         try:
             chan.receive_new_commitment(remote_sig, [])
         except LNProtocolWarning as e:
-            await self.send_warning(channel_id, message=str(e), close_connection=True)
+            self.send_warning(channel_id, message=str(e), close_connection=True)
         chan.open_with_first_pcp(remote_per_commitment_point, remote_sig)
         chan.set_state(ChannelState.OPENING)
         if zeroconf:
@@ -1179,7 +1375,7 @@ class Peer(Logger, EventListener):
         try:
             chan.receive_new_commitment(remote_sig, [])
         except LNProtocolWarning as e:
-            await self.send_warning(channel_id, message=str(e), close_connection=True)
+            self.send_warning(channel_id, message=str(e), close_connection=True)
         sig_64, _ = chan.sign_next_commitment()
         self.send_message('funding_signed',
             channel_id=channel_id,
@@ -1509,8 +1705,7 @@ class Peer(Logger, EventListener):
         timestamp = int(time.time())
         node_id = privkey_to_pubkey(self.privkey)
         features = self.features.for_node_announcement()
-        b = int.bit_length(features)
-        flen = b // 8 + int(bool(b % 8))
+        flen = features.min_len()
         rgb_color = bytes.fromhex('000000')
         alias = bytes(alias, 'utf8')
         alias += bytes(32 - len(alias))
@@ -2444,7 +2639,7 @@ class Peer(Logger, EventListener):
         # BOLT-02 check if they use the upfront shutdown script they advertised
         if self.is_upfront_shutdown_script() and their_upfront_scriptpubkey:
             if not (their_scriptpubkey == their_upfront_scriptpubkey):
-                await self.send_warning(
+                self.send_warning(
                     chan.channel_id,
                     "remote didn't use upfront shutdown script it committed to in channel opening",
                     close_connection=True)
@@ -2455,7 +2650,7 @@ class Peer(Logger, EventListener):
         elif match_script_against_template(their_scriptpubkey, transaction.SCRIPTPUBKEY_TEMPLATE_WITNESS_V0):
             pass
         else:
-            await self.send_warning(
+            self.send_warning(
                 chan.channel_id,
                 f'scriptpubkey in received shutdown message does not conform to any template: {their_scriptpubkey.hex()}',
                 close_connection=True)
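A worked sketch of how _send_reply_channel_range() above tiles the queried block range: each reply carries at most 8100 eight-byte scids (64800 bytes), intermediate replies cover the blocks up to the first scid of the next chunk, and the final reply must reach first_blocknum + number_of_blocks of the query. chunk_ranges is a hypothetical helper that models only the chunking math, not the wire encoding:

def chunk_ranges(heights, first_blocknum, number_of_blocks, chunk_size=8100):
    """heights: sorted block heights of the scids to send back.
    Returns (first_blocknum, number_of_blocks, sync_complete) per reply."""
    ranges = []
    start = first_blocknum
    while True:
        chunk, heights = heights[:chunk_size], heights[chunk_size:]
        if not heights:  # final chunk: cover the remainder of the queried range
            ranges.append((start, (first_blocknum + number_of_blocks) - start, True))
            return ranges
        # intermediate chunk: cover up to the first block of the next chunk
        ranges.append((start, heights[0] - start, False))
        start = heights[0]

# everything fits in one reply: it covers the whole queried range
assert chunk_ranges([500_010, 500_020], 500_000, 1000) == [(500_000, 1000, True)]
# forced two-chunk split (chunk_size=1): the ranges tile [0, 100) without gaps
assert chunk_ranges([10, 20], 0, 100, chunk_size=1) == [(0, 10, False), (10, 90, True)]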
+        self.drain_write_lock = asyncio.Lock()
+
     def name(self) -> str:
         pubkey = self.remote_pubkey()
         pubkey_hex = pubkey.hex() if pubkey else pubkey
@@ -218,6 +221,12 @@ class LNTransportBase:
         assert len(c) == len(msg) + 16
         self.writer.write(lc+c)

+    async def send_bytes_and_drain(self, msg: bytes) -> None:
+        """Should be used where possible (i.e. in async scope) to avoid memory exhaustion."""
+        async with self.drain_write_lock:
+            self.send_bytes(msg)
+            await self.writer.drain()
+
     async def read_messages(self):
         buffer = bytearray()
         while True:
diff --git a/electrum/lnutil.py b/electrum/lnutil.py
index a6b14fb5b..3c7814c25 100644
--- a/electrum/lnutil.py
+++ b/electrum/lnutil.py
@@ -9,6 +9,7 @@ from collections import namedtuple, defaultdict
 from typing import NamedTuple, List, Tuple, Mapping, Optional, TYPE_CHECKING, Union, Dict, Set, Sequence
 import re
 import sys
+import time

 import electrum_ecc as ecc
 from electrum_ecc import CURVE_ORDER, ecdsa_sig64_from_der_sig, ECPubkey, string_to_number
@@ -1509,6 +1510,10 @@ class LnFeatures(IntFlag):
                 features |= (1 << flag)
         return features

+    def min_len(self) -> int:
+        b = int.bit_length(self)
+        return b // 8 + int(bool(b % 8))
+
     def supports(self, feature: 'LnFeatures') -> bool:
         """Returns whether given feature is enabled.
@@ -1634,6 +1639,56 @@ def get_ln_flag_pair_of_bit(flag_bit: int) -> int:
         return flag_bit - 1


+class GossipTimestampFilter:
+    def __init__(self, first_timestamp: int, timestamp_range: int):
+        self.first_timestamp = first_timestamp
+        self.timestamp_range = timestamp_range
+        # True once we have sent the requested gossip; from then on we only forward new gossip
+        self.only_forwarding = False
+        if first_timestamp >= int(time.time()) - 20:
+            self.only_forwarding = True
+
+    def __str__(self):
+        return (f"GossipTimestampFilter | first_timestamp={self.first_timestamp} | "
+                f"timestamp_range={self.timestamp_range}")
+
+    def in_range(self, timestamp: int) -> bool:
+        return self.first_timestamp <= timestamp < self.first_timestamp + self.timestamp_range
+
+    @classmethod
+    def from_payload(cls, payload) -> Optional['GossipTimestampFilter']:
+        try:
+            first_timestamp = payload['first_timestamp']
+            timestamp_range = payload['timestamp_range']
+        except KeyError:
+            return None
+        if first_timestamp >= 0xFFFFFFFF:
+            return None
+        return cls(first_timestamp, timestamp_range)
+
+
+class GossipForwardingMessage:
+    def __init__(self,
+                 msg: bytes,
+                 scid: Optional[ShortChannelID] = None,
+                 timestamp: Optional[int] = None,
+                 sender_node_id: Optional[bytes] = None):
+        self.scid: Optional[ShortChannelID] = scid
+        self.sender_node_id: Optional[bytes] = sender_node_id
+        self.msg = msg
+        self.timestamp = timestamp
+
+    @classmethod
+    def from_payload(cls, payload: dict) -> Optional['GossipForwardingMessage']:
+        try:
+            msg = payload['raw']
+            scid = ShortChannelID.normalize(payload.get('short_channel_id'))
+            sender_node_id = payload.get('sender_node_id')
+            timestamp = payload.get('timestamp')
+        except KeyError:
+            return None
+        return cls(msg, scid, timestamp, sender_node_id)
+
+
 def list_enabled_ln_feature_bits(features: int) -> tuple[int, ...]:
     """Returns a list of enabled feature bits.
     If both opt and req are set, only req will be included in the result."""
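A small usage sketch (not part of the patch, assumes the patched electrum.lnutil) of the GossipTimestampFilter semantics added above: in_range() is a half-open interval, and a filter whose first_timestamp is (roughly) "now" is treated as forwarding-only, i.e. the peer wants no historical gossip:

import time
from electrum.lnutil import GossipTimestampFilter

now = int(time.time())
hist = GossipTimestampFilter(first_timestamp=now - 3600, timestamp_range=3600)
assert hist.in_range(now - 3600)   # lower bound is inclusive
assert not hist.in_range(now)      # upper bound is exclusive
assert not hist.only_forwarding    # asked for the last hour of gossip

live = GossipTimestampFilter(first_timestamp=now, timestamp_range=0xFFFFFFFF)
assert live.only_forwarding        # within 20s of now: only forward new gossip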
diff --git a/electrum/lnworker.py b/electrum/lnworker.py
index 952f1639b..b24ee5907 100644
--- a/electrum/lnworker.py
+++ b/electrum/lnworker.py
@@ -60,7 +60,7 @@ from .lnutil import (
     LnKeyFamily, LOCAL, REMOTE, MIN_FINAL_CLTV_DELTA_FOR_INVOICE, SENT, RECEIVED, HTLCOwner,
     UpdateAddHtlc, LnFeatures, ShortChannelID, HtlcLog, NoPathFound, InvalidGossipMsg,
     FeeBudgetExceeded, ImportedChannelBackupStorage, OnchainChannelBackupStorage,
     ln_compare_features, IncompatibleLightningFeatures, PaymentFeeBudget,
-    NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE
+    NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE, GossipForwardingMessage
 )
 from .lnonion import decode_onion_error, OnionFailureCode, OnionRoutingFailure, OnionPacket
 from .lnmsg import decode_msg
@@ -162,7 +162,6 @@ LNWALLET_FEATURES = (
     BASE_FEATURES
     | LnFeatures.OPTION_DATA_LOSS_PROTECT_REQ
     | LnFeatures.OPTION_STATIC_REMOTEKEY_REQ
-    | LnFeatures.GOSSIP_QUERIES_REQ
     | LnFeatures.VAR_ONION_REQ
     | LnFeatures.PAYMENT_SECRET_REQ
     | LnFeatures.BASIC_MPP_OPT
@@ -175,8 +174,10 @@ LNWALLET_FEATURES = (

 LNGOSSIP_FEATURES = (
     BASE_FEATURES
-    | LnFeatures.GOSSIP_QUERIES_OPT
+    # LNGossip doesn't serve gossip, but weirdly it has to signal GOSSIP_QUERIES_OPT
+    # so that peers satisfy our queries
     | LnFeatures.GOSSIP_QUERIES_REQ
+    | LnFeatures.GOSSIP_QUERIES_OPT
 )

@@ -290,7 +291,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
             peer_addr = LNPeerAddr(host, port, node_id)
             self._trying_addr_now(peer_addr)
             self.logger.info(f"adding peer {peer_addr}")
-            if node_id == self.node_keypair.pubkey:
+            if node_id == self.node_keypair.pubkey or self.is_our_lnwallet(node_id):
                 raise ErrorAddingPeer("cannot connect to self")
             transport = LNTransport(self.node_keypair.privkey, peer_addr,
                                     e_proxy=ESocksProxy.from_network_settings(self.network))
@@ -327,6 +328,14 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
     def num_peers(self) -> int:
         return sum([p.is_initialized() for p in self.peers.values()])

+    def is_our_lnwallet(self, node_id: bytes) -> bool:
+        """Check if node_id is one of our own wallets."""
+        wallets = self.network.daemon.get_wallets()
+        for wallet in wallets.values():
+            if wallet.lnworker and wallet.lnworker.node_keypair.pubkey == node_id:
+                return True
+        return False
+
     def start_network(self, network: 'Network'):
         assert network
         assert self.network is None, "already started"
@@ -511,6 +520,12 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):


 class LNGossip(LNWorker):
+    """LNGossip is a separate, unannounced Lightning node with a random id that only queries
+    gossip from other nodes; it does not itself satisfy gossip queries. Serving queries is done
+    by the LNWallet class(es): LNWallets are the advertised nodes used for actual payments and
+    only satisfy peer queries, without fetching gossip themselves. This separation allows gossip
+    to be queried independently of the active LNWallets. LNGossip keeps a curated batch of gossip
+    in _forwarding_gossip that is fetched by the LNWallets for regular forwarding."""
     max_age = 14*24*3600
     LOGGING_SHORTCUT = 'g'

@@ -521,12 +536,17 @@ class LNGossip(LNWorker):
         node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY)
         LNWorker.__init__(self, node_keypair, LNGOSSIP_FEATURES, config=config)
         self.unknown_ids = set()
+        self._forwarding_gossip = []  # type: List[GossipForwardingMessage]
+        self._last_gossip_batch_ts = 0  # type: int
+        self._forwarding_gossip_lock = asyncio.Lock()
+        self.gossip_request_semaphore = asyncio.Semaphore(5)

     def start_network(self, network: 'Network'):
         super().start_network(network)
         for coro in [
                 self._maintain_connectivity(),
                 self.maintain_db(),
+                self._maintain_forwarding_gossip()
         ]:
             tg_coro = self.taskgroup.spawn(coro)
             asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
@@ -539,6 +559,20 @@ class LNGossip(LNWorker):
             self.channel_db.prune_orphaned_channels()
             await asyncio.sleep(120)

+    async def _maintain_forwarding_gossip(self):
+        await self.channel_db.data_loaded.wait()
+        await self.wait_for_sync()
+        while True:
+            async with self._forwarding_gossip_lock:
+                self._forwarding_gossip = self.channel_db.get_forwarding_gossip_batch()
+                self._last_gossip_batch_ts = int(time.time())
+            self.logger.debug(f"{len(self._forwarding_gossip)} gossip messages available to forward")
+            await asyncio.sleep(60)
+
+    async def get_forwarding_gossip(self) -> tuple[List[GossipForwardingMessage], int]:
+        async with self._forwarding_gossip_lock:
+            return self._forwarding_gossip, self._last_gossip_batch_ts
+
     async def add_new_ids(self, ids: Iterable[bytes]):
         known = self.channel_db.get_channel_ids()
         new = set(ids) - set(known)
@@ -563,12 +597,19 @@ class LNGossip(LNWorker):
             return None, None, None
         nchans_with_0p, nchans_with_1p, nchans_with_2p = self.channel_db.get_num_channels_partitioned_by_policy_count()
         num_db_channels = nchans_with_0p + nchans_with_1p + nchans_with_2p
+        num_nodes = self.channel_db.num_nodes
+        num_nodes_associated_to_chans = max(len(self.channel_db._channels_for_node.keys()), 1)
         # some channels will never have two policies (only one is in gossip?...)
         # so if we have at least 1 policy for a channel, we consider that channel "complete" here
         current_est = num_db_channels - nchans_with_0p
         total_est = len(self.unknown_ids) + num_db_channels
-        progress = current_est / total_est if total_est and current_est else 0
+        progress_chans = current_est / total_est if total_est and current_est else 0
+        # consider node anns synced once we have them for at least 10% of the node ids we know about
+        progress_nodes = min((num_nodes / num_nodes_associated_to_chans) * 10, 1)
+        progress = (progress_chans * 3 + progress_nodes) / 4  # weigh the channel progress higher
+        # self.logger.debug(f"Sync progress chans: {progress_chans} | Progress nodes: {progress_nodes} | "
+        #                   f"Total progress: {progress} | NUM_NODES: {num_nodes} / {num_nodes_associated_to_chans}")
         progress_percent = (1.0 / 0.95 * progress) * 100
         progress_percent = min(progress_percent, 100)
         progress_percent = round(progress_percent)
@@ -582,8 +623,8 @@ class LNGossip(LNWorker):
         # note: we run in the originating peer's TaskGroup, so we can safely raise here
         # and disconnect only from that peer
         await self.channel_db.data_loaded.wait()
-        self.logger.debug(f'process_gossip {len(chan_anns)} {len(node_anns)} {len(chan_upds)}')
-
+        self.logger.debug(f'process_gossip ca: {len(chan_anns)} na: {len(node_anns)} '
+                          f'cu: {len(chan_upds)}')
         # channel announcements
         def process_chan_anns():
             for payload in chan_anns:
@@ -610,6 +651,24 @@ class LNGossip(LNWorker):
         if categorized_chan_upds.good:
             self.logger.debug(f'process_gossip: {len(categorized_chan_upds.good)}/{len(chan_upds)}')

+    def is_synced(self) -> bool:
+        _, _, percentage_synced = self.get_sync_progress_estimate()
+        if percentage_synced is not None and percentage_synced >= 100:
+            return True
+        return False
+
+    async def wait_for_sync(self, times_to_check: int = 3):
+        """Wait until we have seen 100% sync progress `times_to_check` times in a row (because the
+        estimate often jumps back after some seconds when doing initial sync)."""
+        while True:
+            if self.is_synced():
+                times_to_check -= 1
+                if times_to_check <= 0:
+                    return
+            await asyncio.sleep(10)
+            # flush the gossip queue so we don't forward old gossip after sync is complete
+            self.channel_db.get_forwarding_gossip_batch()
+

 class PaySession(Logger):
     def __init__(
@@ -771,6 +830,8 @@ class LNWallet(LNWorker):
             features |= LnFeatures.OPTION_ANCHORS_ZERO_FEE_HTLC_OPT
         if self.config.ACCEPT_ZEROCONF_CHANNELS:
             features |= LnFeatures.OPTION_ZEROCONF_OPT
+        if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS and self.config.LIGHTNING_USE_GOSSIP:
+            features |= LnFeatures.GOSSIP_QUERIES_OPT  # signal we have gossip to fetch
         LNWorker.__init__(self, self.node_keypair, features, config=self.config)
         self.lnwatcher = None
         self.lnrater: LNRater = None
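Worked numbers (illustrative values, not from the patch) for the revised sync estimate in get_sync_progress_estimate() above: channel progress is weighted 3:1 against node-announcement progress, and the node term saturates once node_anns exist for at least 10% of the node ids seen in channels:

num_db_channels, nchans_with_0p, num_unknown_ids = 40_000, 4_000, 4_000
num_nodes, num_nodes_associated_to_chans = 1_200, 10_000

current_est = num_db_channels - nchans_with_0p   # 36_000 channels with >= 1 policy
total_est = num_unknown_ids + num_db_channels    # 44_000
progress_chans = current_est / total_est         # ~0.818
progress_nodes = min((num_nodes / num_nodes_associated_to_chans) * 10, 1)  # 1.0 (saturated)
progress = (progress_chans * 3 + progress_nodes) / 4                       # ~0.864
progress_percent = min(round((1.0 / 0.95 * progress) * 100), 100)
assert progress_percent == 91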
diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py
index 43855f1c0..1ce1dfcc7 100644
--- a/tests/test_lnpeer.py
+++ b/tests/test_lnpeer.py
@@ -1247,7 +1247,7 @@ class TestPeerDirect(TestPeer):
         async def action():
             await util.wait_for2(p1.initialized, 1)
             await util.wait_for2(p2.initialized, 1)
-            await p1.send_warning(alice_channel.channel_id, 'be warned!', close_connection=True)
+            p1.send_warning(alice_channel.channel_id, 'be warned!', close_connection=True)
         gath = asyncio.gather(action(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
         with self.assertRaises(GracefulDisconnect):
             await gath
@@ -1259,7 +1259,7 @@ class TestPeerDirect(TestPeer):
         async def action():
             await util.wait_for2(p1.initialized, 1)
             await util.wait_for2(p2.initialized, 1)
-            await p1.send_error(alice_channel.channel_id, 'some error happened!', force_close_channel=True)
+            p1.send_error(alice_channel.channel_id, 'some error happened!', force_close_channel=True)
             assert alice_channel.is_closed()
             gath.cancel()
         gath = asyncio.gather(action(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
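Finally, a sketch (standalone reimplementation, not the patched class) of the LnFeatures.min_len() helper that replaces the duplicated flen computation in lnpeer.py above, i.e. the minimal byte length needed to hold the highest set feature bit:

def min_len(features: int) -> int:
    b = features.bit_length()
    return b // 8 + int(bool(b % 8))  # ceil(b / 8) without floats

assert min_len(0) == 0        # no feature bits -> empty feature vector
assert min_len(1 << 7) == 1   # bit 7 still fits in a single byte
assert min_len(1 << 8) == 2   # bit 8 needs a second byte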