Merge pull request #10153 from SomberNight/202508_lnpeer_rate_limits

lnpeer: add some rate-limits
ghost43 authored on 2025-08-20 12:28:50 +00:00 (committed by GitHub)
2 changed files with 99 additions and 26 deletions

electrum/lnpeer.py

@@ -12,6 +12,7 @@ import time
 from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set, Callable, Awaitable, List
 from datetime import datetime
 import functools
+from functools import partial
 import electrum_ecc as ecc
 from electrum_ecc import ecdsa_sig64_from_r_and_s, ecdsa_der_sig_from_ecdsa_sig64, ECPubkey
@@ -78,6 +79,8 @@ class Peer(Logger, EventListener):
         'query_short_channel_ids', 'reply_short_channel_ids', 'reply_short_channel_ids_end')
     DELAY_INC_MSG_PROCESSING_SLEEP = 0.01
+    RECV_GOSSIP_QUEUE_SOFT_MAXSIZE = 2000
+    RECV_GOSSIP_QUEUE_HARD_MAXSIZE = 5000
 
     def __init__(
             self,
@@ -104,12 +107,13 @@ class Peer(Logger, EventListener):
         assert self.node_ids[0] != self.node_ids[1]
         self.last_message_time = 0
         self.pong_event = asyncio.Event()
-        self.reply_channel_range = asyncio.Queue()
+        self.reply_channel_range = None  # type: Optional[asyncio.Queue]
-        # gossip uses a single queue to preserve message order
-        self.gossip_queue = asyncio.Queue()
-        self.gossip_timestamp_filter = None  # type: Optional[GossipTimestampFilter]
+        self.recv_gossip_queue = asyncio.Queue(maxsize=self.RECV_GOSSIP_QUEUE_HARD_MAXSIZE)
+        self.our_gossip_timestamp_filter = None  # type: Optional[GossipTimestampFilter]
+        self.their_gossip_timestamp_filter = None  # type: Optional[GossipTimestampFilter]
         self.outgoing_gossip_reply = False  # type: bool
-        self.ordered_message_queues = defaultdict(asyncio.Queue)  # type: Dict[bytes, asyncio.Queue]  # for messages that are ordered
+        self.ordered_message_queues = defaultdict(partial(asyncio.Queue, maxsize=10))  # type: Dict[bytes, asyncio.Queue]  # for messages that are ordered
         self.temp_id_to_id = {}  # type: Dict[bytes, Optional[bytes]]  # to forward error messages
         self.funding_created_sent = set()  # for channels in PREOPENING
         self.funding_signed_sent = set()  # for channels in PREOPENING
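
Note on the ordered_message_queues change above: defaultdict calls its default_factory with no arguments, so a bounded queue's maxsize keyword has to be pre-bound with functools.partial, which is what the new import at the top of the file is for. A minimal, self-contained sketch of the idiom (variable names here are illustrative, not from the diff):

import asyncio
from collections import defaultdict
from functools import partial

# defaultdict invokes its factory with no arguments, so the maxsize
# keyword must be baked in ahead of time with functools.partial.
queues = defaultdict(partial(asyncio.Queue, maxsize=10))

q = queues[b"chan_id_1"]       # queue is created on first access, already bounded
assert q.maxsize == 10

for i in range(10):
    q.put_nowait(("msg", i))
try:
    q.put_nowait(("msg", 10))  # the 11th item exceeds the bound
except asyncio.QueueFull:
    # in lnpeer this exception propagates and disconnects the peer
    print("queue is full, as expected")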
@@ -225,6 +229,12 @@ class Peer(Logger, EventListener):
             return
         if message_type in self.ORDERED_MESSAGES:
             chan_id = payload.get('channel_id') or payload["temporary_channel_id"]
+            if (
+                chan_id not in self.channels
+                and chan_id not in self.temp_id_to_id
+                and chan_id not in self.temp_id_to_id.values()
+            ):
+                raise Exception(f"received {message_type} for unknown {chan_id.hex()=}")
             self.ordered_message_queues[chan_id].put_nowait((message_type, payload))
         else:
             if message_type not in ('error', 'warning') and 'channel_id' in payload:
@@ -399,17 +409,26 @@ class Peer(Logger, EventListener):
         self.maybe_set_initialized()
 
     def on_node_announcement(self, payload):
-        if not self.lnworker.uses_trampoline():
-            self.gossip_queue.put_nowait(('node_announcement', payload))
+        if self.lnworker.uses_trampoline():
+            return
+        if self.our_gossip_timestamp_filter is None:
+            return  # why is the peer sending this? should we disconnect?
+        self.recv_gossip_queue.put_nowait(('node_announcement', payload))
 
     def on_channel_announcement(self, payload):
-        if not self.lnworker.uses_trampoline():
-            self.gossip_queue.put_nowait(('channel_announcement', payload))
+        if self.lnworker.uses_trampoline():
+            return
+        if self.our_gossip_timestamp_filter is None:
+            return  # why is the peer sending this? should we disconnect?
+        self.recv_gossip_queue.put_nowait(('channel_announcement', payload))
 
     def on_channel_update(self, payload):
         self.maybe_save_remote_update(payload)
-        if not self.lnworker.uses_trampoline():
-            self.gossip_queue.put_nowait(('channel_update', payload))
+        if self.lnworker.uses_trampoline():
+            return
+        if self.our_gossip_timestamp_filter is None:
+            return  # why is the peer sending this? should we disconnect?
+        self.recv_gossip_queue.put_nowait(('channel_update', payload))
 
     def on_query_channel_range(self, payload):
         if self.lnworker == self.lnworker.network.lngossip or not self._should_forward_gossip():
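
The three gossip handlers above now share the same shape: drop gossip in trampoline mode, drop unsolicited gossip when we never sent a filter, otherwise enqueue into the bounded recv_gossip_queue. A hypothetical distillation of that shared guard as a free function (the function name and parameters are illustrative, not from the diff):

import asyncio
from typing import Optional

def enqueue_gossip_guarded(
        queue: asyncio.Queue,
        our_filter: Optional[object],
        uses_trampoline: bool,
        name: str,
        payload: Optional[dict],
) -> None:
    """Illustrative distillation of the guard in the three handlers above."""
    if uses_trampoline:
        return  # trampoline nodes don't consume gossip
    if our_filter is None:
        return  # unsolicited gossip: we never sent a gossip_timestamp_filter
    # the bounded queue raises asyncio.QueueFull at the hard limit,
    # which propagates and disconnects the peer
    queue.put_nowait((name, payload))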
@@ -419,7 +438,7 @@ class Peer(Logger, EventListener):
         if self.outgoing_gossip_reply:
             return self.send_warning(bytes(32), "received multiple queries at the same time")
         self.outgoing_gossip_reply = True
-        self.gossip_queue.put_nowait(('query_channel_range', payload))
+        self.recv_gossip_queue.put_nowait(('query_channel_range', payload))
 
     def on_query_short_channel_ids(self, payload):
         if self.lnworker == self.lnworker.network.lngossip or not self._should_forward_gossip():
@@ -429,7 +448,7 @@ class Peer(Logger, EventListener):
         if not self._is_valid_short_channel_id_query(payload):
             return self.send_warning(bytes(32), "invalid query_short_channel_ids")
         self.outgoing_gossip_reply = True
-        self.gossip_queue.put_nowait(('query_short_channel_ids', payload))
+        self.recv_gossip_queue.put_nowait(('query_short_channel_ids', payload))
 
     def on_gossip_timestamp_filter(self, payload):
         if self._should_forward_gossip():
@@ -441,11 +460,11 @@ class Peer(Logger, EventListener):
         if payload.get('chain_hash') != constants.net.rev_genesis_bytes():
             return
         filter = GossipTimestampFilter.from_payload(payload)
-        self.gossip_timestamp_filter = filter
+        self.their_gossip_timestamp_filter = filter
         self.logger.debug(f"got gossip_ts_filter from peer {self.pubkey.hex()}: "
-                          f"{str(self.gossip_timestamp_filter)}")
+                          f"{str(self.their_gossip_timestamp_filter)}")
         if filter and not filter.only_forwarding:
-            self.gossip_queue.put_nowait(('gossip_timestamp_filter', None))
+            self.recv_gossip_queue.put_nowait(('gossip_timestamp_filter', None))
 
     def maybe_save_remote_update(self, payload):
         if not self.channels:
@@ -521,7 +540,7 @@ class Peer(Logger, EventListener):
         chan_upds = []
         node_anns = []
         while True:
-            name, payload = await self.gossip_queue.get()
+            name, payload = await self.recv_gossip_queue.get()
             if name == 'channel_announcement':
                 chan_anns.append(payload)
             elif name == 'channel_update':
@@ -536,7 +555,7 @@ class Peer(Logger, EventListener):
                 await self.taskgroup.spawn(self._handle_historical_gossip_request())
             else:
                 raise Exception('unknown message')
-            if self.gossip_queue.empty():
+            if self.recv_gossip_queue.empty():
                 break
         if self.network.lngossip:
             await self.network.lngossip.process_gossip(chan_anns, node_anns, chan_upds)
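
The loop above blocks on the first get(), then keeps consuming until recv_gossip_queue happens to be empty, so gossip is handed to lngossip.process_gossip in batches rather than one message at a time. A standalone sketch of that drain-until-empty pattern (the queue contents and the batch handling are made up):

import asyncio

async def drain_in_batches(queue: asyncio.Queue) -> None:
    while True:
        batch = [await queue.get()]        # block until something arrives
        while not queue.empty():           # then take whatever else is queued
            batch.append(queue.get_nowait())
        print(f"processing a batch of {len(batch)} gossip messages")

async def main() -> None:
    q = asyncio.Queue()
    for i in range(5):
        q.put_nowait(("channel_update", i))
    consumer = asyncio.create_task(drain_in_batches(q))
    await asyncio.sleep(0.1)               # let one batch get processed
    consumer.cancel()

asyncio.run(main())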
@@ -577,7 +596,7 @@ class Peer(Logger, EventListener):
         last_gossip_batch_ts = 0
         while True:
             await asyncio.sleep(10)
-            if not self.gossip_timestamp_filter:
+            if not self.their_gossip_timestamp_filter:
                 continue  # peer didn't request gossip
 
             new_gossip, last_lngossip_refresh_ts = await lngossip.get_forwarding_gossip()
@@ -589,7 +608,7 @@ class Peer(Logger, EventListener):
     async def _handle_historical_gossip_request(self):
         """Called when a peer requests historical gossip with a gossip_timestamp_filter query."""
-        filter = self.gossip_timestamp_filter
+        filter = self.their_gossip_timestamp_filter
         if not self._should_forward_gossip() or not filter or filter.only_forwarding:
             return
         async with self.network.lngossip.gossip_request_semaphore:
@@ -603,7 +622,7 @@ class Peer(Logger, EventListener):
     async def _send_gossip_messages(self, messages: List[GossipForwardingMessage]) -> int:
         amount_sent = 0
         for msg in messages:
-            if self.gossip_timestamp_filter.in_range(msg.timestamp) \
+            if self.their_gossip_timestamp_filter.in_range(msg.timestamp) \
                     and self.pubkey != msg.sender_node_id:
                 await self.transport.send_bytes_and_drain(msg.msg)
                 amount_sent += 1
@@ -697,6 +716,7 @@ class Peer(Logger, EventListener):
         self.outgoing_gossip_reply = False
 
     async def get_channel_range(self):
+        self.reply_channel_range = asyncio.Queue()
         first_block = constants.net.BLOCK_HEIGHT_FIRST_LIGHTNING_CHANNELS
         num_blocks = self.lnworker.network.get_local_height() - first_block
         self.query_channel_range(first_block, num_blocks)
@@ -735,6 +755,7 @@ class Peer(Logger, EventListener):
             a, b = intervals[0]
             if a <= first_block and b >= first_block + num_blocks:
                 break
+        self.reply_channel_range = None
         return ids, complete
 
     def request_gossip(self, timestamp=0):
@@ -742,11 +763,17 @@ class Peer(Logger, EventListener):
             self.logger.info('requesting whole channel graph')
         else:
             self.logger.info(f'requesting channel graph since {datetime.fromtimestamp(timestamp).isoformat()}')
+        timestamp_range = 0xFFFFFFFF
+        self.our_gossip_timestamp_filter = GossipTimestampFilter(
+            first_timestamp=timestamp,
+            timestamp_range=timestamp_range,
+        )
         self.send_message(
             'gossip_timestamp_filter',
             chain_hash=constants.net.rev_genesis_bytes(),
             first_timestamp=timestamp,
-            timestamp_range=b'\xff'*4)
+            timestamp_range=timestamp_range,
+        )
 
     def query_channel_range(self, first_block, num_blocks):
         self.logger.info(f'query channel range {first_block} {num_blocks}')
@@ -766,7 +793,7 @@ class Peer(Logger, EventListener):
         ids = [decoded[i:i+8] for i in range(0, len(decoded), 8)]
         return ids
 
-    def on_reply_channel_range(self, payload):
+    async def on_reply_channel_range(self, payload):
         first = payload['first_blocknum']
         num = payload['number_of_blocks']
         complete = bool(int.from_bytes(payload['sync_complete'], 'big'))
@@ -774,6 +801,12 @@ class Peer(Logger, EventListener):
         ids = self.decode_short_ids(encoded)
         # self.logger.info(f"on_reply_channel_range. >>> first_block {first}, num_blocks {num}, "
         #                  f"num_ids {len(ids)}, complete {complete}")
+        if self.reply_channel_range is None:
+            raise Exception("received 'reply_channel_range' without corresponding 'query_channel_range'")
+        while self.reply_channel_range.qsize() > 10:
+            # we block process_message until the queue gets consumed
+            self.logger.info("reply_channel_range queue is overflowing. sleeping...")
+            await asyncio.sleep(0.1)
         self.reply_channel_range.put_nowait((first, num, complete, ids))
 
     async def _send_reply_short_channel_ids(self, payload: dict):
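
Three hunks cooperate here: get_channel_range creates reply_channel_range only for the duration of one query, the guard above rejects replies that arrive without a pending query, and the qsize() > 10 loop stalls the peer's message processing rather than buffering replies without bound. A condensed model of that lifecycle, assuming a simplified reply tuple (class and method names here are illustrative):

import asyncio
from typing import Optional

class RangeQuerySketch:
    """Condensed model of the reply_channel_range handling above."""

    def __init__(self) -> None:
        self.reply_queue: Optional[asyncio.Queue] = None

    async def query(self) -> list:
        self.reply_queue = asyncio.Queue()   # exists only while a query is in flight
        ids = []
        while True:
            batch, complete = await self.reply_queue.get()
            ids.extend(batch)
            if complete:
                break
        self.reply_queue = None              # stray late replies now raise
        return ids

    async def on_reply(self, batch: list, complete: bool) -> None:
        if self.reply_queue is None:
            raise Exception("reply_channel_range without query_channel_range")
        while self.reply_queue.qsize() > 10:
            # backpressure: stall the sender's message loop until we catch up
            await asyncio.sleep(0.1)
        self.reply_queue.put_nowait((batch, complete))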
@@ -830,6 +863,15 @@ class Peer(Logger, EventListener):
             # rate-limit message-processing a bit, to make it harder
             # for a single peer to bog down the event loop / cpu:
             await asyncio.sleep(self.DELAY_INC_MSG_PROCESSING_SLEEP)
+            # If receiving too much gossip from this peer, we need to slow them down.
+            # note: if the gossip queue gets full, we will disconnect from them
+            #       and throw away unprocessed gossip.
+            if self.recv_gossip_queue.qsize() > self.RECV_GOSSIP_QUEUE_SOFT_MAXSIZE:
+                sleep = self.recv_gossip_queue.qsize() / 1000
+                self.logger.debug(
+                    f"message_loop sleeping due to getting much gossip. qsize={self.recv_gossip_queue.qsize()}. "
+                    f"waiting for existing gossip data to be processed first.")
+                await asyncio.sleep(sleep)
 
     def on_reply_short_channel_ids_end(self, payload):
         self.querying.set()
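
The effect of the new block: the sleep grows linearly with queue depth, so at the soft limit of 2000 queued messages the loop pauses about 2 seconds before reading the next message, and at the hard limit of 5000 the bounded queue's put_nowait raises asyncio.QueueFull, which disconnects the peer. A toy version of the two-tier policy (the constants match the diff; read_msg is a stand-in for the transport):

import asyncio
from typing import Awaitable, Callable

SOFT_MAXSIZE = 2000   # above this, start slowing the peer down
HARD_MAXSIZE = 5000   # queue bound; overflowing it disconnects the peer

async def throttled_message_loop(
        read_msg: Callable[[], Awaitable[tuple]],
        gossip_queue: asyncio.Queue,
) -> None:
    while True:
        msg = await read_msg()
        # QueueFull propagates out of here and the caller drops the connection
        gossip_queue.put_nowait(msg)
        if gossip_queue.qsize() > SOFT_MAXSIZE:
            # proportional backoff: 2000 queued -> 2.0s, 4000 queued -> 4.0s
            await asyncio.sleep(gossip_queue.qsize() / 1000)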
@@ -1050,8 +1092,8 @@ class Peer(Logger, EventListener):
             int.from_bytes(per_commitment_secret_first, 'big'))
         # store the temp id now, so that it is recognized for e.g. 'error' messages
-        # TODO: this is never cleaned up; the dict grows unbounded until disconnect
         self.temp_id_to_id[temp_channel_id] = None
+        self._cleanup_temp_channelids()
         self.send_message(
             "open_channel",
             temporary_channel_id=temp_channel_id,
@@ -1270,8 +1312,8 @@ class Peer(Logger, EventListener):
         feerate = payload['feerate_per_kw']  # note: we are not validating this
         temp_chan_id = payload['temporary_channel_id']
         # store the temp id now, so that it is recognized for e.g. 'error' messages
-        # TODO: this is never cleaned up; the dict grows unbounded until disconnect
         self.temp_id_to_id[temp_chan_id] = None
+        self._cleanup_temp_channelids()
         channel_opening_fee = open_channel_tlvs.get('channel_opening_fee') if open_channel_tlvs else None
         if channel_opening_fee:
             # todo check that the fee is reasonable
@@ -1413,6 +1455,15 @@ class Peer(Logger, EventListener):
             self.send_channel_ready(chan)
         self.lnworker.add_new_channel(chan)
 
+    def _cleanup_temp_channelids(self) -> None:
+        self.temp_id_to_id = {
+            tmp_id: chan_id for (tmp_id, chan_id) in self.temp_id_to_id.items()
+            if chan_id not in self.channels
+        }
+        if len(self.temp_id_to_id) > 25:
+            # which one of us is opening all these chans?! let's disconnect
+            raise Exception("temp_id_to_id is getting too large.")
+
     async def request_force_close(self, channel_id: bytes):
         """Try to trigger the remote peer to force-close."""
         await self.initialized
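
A worked example of the pruning rule in _cleanup_temp_channelids, with made-up ids: an entry is kept only while its real channel id has not yet appeared in self.channels, and anything beyond 25 surviving entries is treated as abuse:

channels = {b"chan_A": "Channel(...)"}                 # channel A is now open
temp_id_to_id = {b"tmp_A": b"chan_A", b"tmp_B": None}  # tmp_B still pre-opening

temp_id_to_id = {
    tmp_id: chan_id for (tmp_id, chan_id) in temp_id_to_id.items()
    if chan_id not in channels
}
assert temp_id_to_id == {b"tmp_B": None}   # tmp_A's channel is open, so it was pruned

if len(temp_id_to_id) > 25:
    # which one of us is opening all these chans?! let's disconnect
    raise Exception("temp_id_to_id is getting too large.")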

electrum/lnworker.py

@@ -202,6 +202,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
         self.lock = threading.RLock()
         self.node_keypair = node_keypair
         self._peers = {}  # type: Dict[bytes, Peer]  # pubkey -> Peer  # needs self.lock
+        self._channelless_incoming_peers = set()  # type: Set[bytes]  # node_ids  # needs self.lock
         self.taskgroup = OldTaskGroup()
         self.listen_server = None  # type: Optional[asyncio.AbstractServer]
         self.features = features
@@ -252,13 +253,15 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
             return
         addr = str(netaddr.host)
 
-        async def cb(reader, writer):
+        async def cb(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
             transport = LNResponderTransport(self.node_keypair.privkey, reader, writer)
             try:
                 node_id = await transport.handshake()
             except Exception as e:
                 self.logger.info(f'handshake failure from incoming connection: {e!r}')
                 return
+            peername = writer.get_extra_info('peername')
+            self.logger.debug(f"handshake done for incoming peer: {peername=}, node_id={node_id.hex()}")
             await self._add_peer_from_transport(node_id=node_id, transport=transport)
         try:
             self.listen_server = await asyncio.start_server(cb, addr, netaddr.port)
@@ -315,11 +318,29 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
             # both keep trying to reconnect, resulting in neither being usable.
             if existing_peer.is_initialized():
                 # give priority to the existing connection
-                return
+                transport.close()
+                return None
             else:
                 # Use the new connection. (e.g. old peer might be an outgoing connection
                 # for an outdated host/port that will never connect)
                 existing_peer.close_and_cleanup()
+        # limit max number of incoming channel-less peers.
+        # what to do if limit is reached?
+        # - chosen strategy: we don't allow new connections.
+        #   - drawback: attacker can use up all our slots
+        # - alternative: kick oldest channel-less peer
+        #   - drawback: if many legit peers want to connect to us, we will keep kicking them
+        #     in round-robin, and they will keep reconnecting. no stable state -> we self-DOS
+        # TODO make slots IP-based?
+        if isinstance(transport, LNResponderTransport):
+            assert node_id not in self._channelless_incoming_peers
+            chans = [chan for chan in self.channels_for_peer(node_id).values() if chan.is_funded()]
+            if not chans:
+                if len(self._channelless_incoming_peers) > 100:
+                    transport.close()
+                    return None
+                self._channelless_incoming_peers.add(node_id)
+        # checks done: we are adding this peer.
         peer = Peer(self, node_id, transport)
         assert node_id not in self._peers
         self._peers[node_id] = peer
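
Reduced to its essentials, the admission policy above gives unlimited slots to peers we have a funded channel with, and roughly 100 shared slots to everyone else; the slot is freed again when the peer is cleaned up (last hunk below). A sketch with simplified types (admit_incoming_peer and its parameters are illustrative, not real method names):

from typing import Set

MAX_CHANNELLESS_INCOMING = 100  # slot limit used in the hunk above

def admit_incoming_peer(
        node_id: bytes,
        channelless_peers: Set[bytes],
        has_funded_channel: bool,
) -> bool:
    """Sketch of the admission check for incoming connections."""
    if has_funded_channel:
        return True                              # peers with channels always connect
    if len(channelless_peers) > MAX_CHANNELLESS_INCOMING:
        return False                             # all slots taken: refuse
    channelless_peers.add(node_id)               # slot is freed again on disconnect
    return True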
@@ -331,6 +352,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
             peer2 = self._peers.get(peer.pubkey)
             if peer2 is peer:
                 self._peers.pop(peer.pubkey)
+                self._channelless_incoming_peers.discard(peer.pubkey)
 
     def num_peers(self) -> int:
         return sum([p.is_initialized() for p in self.peers.values()])