Merge pull request #10153 from SomberNight/202508_lnpeer_rate_limits
lnpeer: add some rate-limits
electrum/lnpeer.py
@@ -12,6 +12,7 @@ import time
 from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set, Callable, Awaitable, List
 from datetime import datetime
 import functools
+from functools import partial

 import electrum_ecc as ecc
 from electrum_ecc import ecdsa_sig64_from_r_and_s, ecdsa_der_sig_from_ecdsa_sig64, ECPubkey
@@ -78,6 +79,8 @@ class Peer(Logger, EventListener):
         'query_short_channel_ids', 'reply_short_channel_ids', 'reply_short_channel_ids_end')

     DELAY_INC_MSG_PROCESSING_SLEEP = 0.01
+    RECV_GOSSIP_QUEUE_SOFT_MAXSIZE = 2000
+    RECV_GOSSIP_QUEUE_HARD_MAXSIZE = 5000

     def __init__(
             self,
@@ -104,12 +107,13 @@ class Peer(Logger, EventListener):
         assert self.node_ids[0] != self.node_ids[1]
         self.last_message_time = 0
         self.pong_event = asyncio.Event()
-        self.reply_channel_range = asyncio.Queue()
+        self.reply_channel_range = None # type: Optional[asyncio.Queue]
         # gossip uses a single queue to preserve message order
-        self.gossip_queue = asyncio.Queue()
-        self.gossip_timestamp_filter = None # type: Optional[GossipTimestampFilter]
+        self.recv_gossip_queue = asyncio.Queue(maxsize=self.RECV_GOSSIP_QUEUE_HARD_MAXSIZE)
+        self.our_gossip_timestamp_filter = None # type: Optional[GossipTimestampFilter]
+        self.their_gossip_timestamp_filter = None # type: Optional[GossipTimestampFilter]
         self.outgoing_gossip_reply = False # type: bool
-        self.ordered_message_queues = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue] # for messages that are ordered
+        self.ordered_message_queues = defaultdict(partial(asyncio.Queue, maxsize=10)) # type: Dict[bytes, asyncio.Queue] # for messages that are ordered
         self.temp_id_to_id = {} # type: Dict[bytes, Optional[bytes]] # to forward error messages
         self.funding_created_sent = set() # for channels in PREOPENING
         self.funding_signed_sent = set() # for channels in PREOPENING
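
Note on the bounded per-channel queues above: replacing defaultdict(asyncio.Queue) with defaultdict(partial(asyncio.Queue, maxsize=10)) means each channel can have at most 10 unprocessed ordered messages; an 11th put_nowait() raises asyncio.QueueFull, which bubbles out of message processing and in practice ends with the peer being disconnected. A minimal, self-contained sketch of that behaviour (illustrative only, not part of the commit):

    import asyncio
    from collections import defaultdict
    from functools import partial

    # one bounded FIFO per channel id; a fresh queue is created on first access
    ordered_message_queues = defaultdict(partial(asyncio.Queue, maxsize=10))

    chan_id = bytes(32)
    for i in range(10):
        ordered_message_queues[chan_id].put_nowait(('update_add_htlc', {'htlc_id': i}))
    try:
        ordered_message_queues[chan_id].put_nowait(('update_add_htlc', {'htlc_id': 10}))
    except asyncio.QueueFull:
        print("11th unprocessed message for one channel -> peer would be disconnected")
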
@@ -225,6 +229,12 @@ class Peer(Logger, EventListener):
             return
         if message_type in self.ORDERED_MESSAGES:
             chan_id = payload.get('channel_id') or payload["temporary_channel_id"]
+            if (
+                chan_id not in self.channels
+                and chan_id not in self.temp_id_to_id
+                and chan_id not in self.temp_id_to_id.values()
+            ):
+                raise Exception(f"received {message_type} for unknown {chan_id.hex()=}")
             self.ordered_message_queues[chan_id].put_nowait((message_type, payload))
         else:
             if message_type not in ('error', 'warning') and 'channel_id' in payload:
@@ -399,17 +409,26 @@ class Peer(Logger, EventListener):
         self.maybe_set_initialized()

     def on_node_announcement(self, payload):
-        if not self.lnworker.uses_trampoline():
-            self.gossip_queue.put_nowait(('node_announcement', payload))
+        if self.lnworker.uses_trampoline():
+            return
+        if self.our_gossip_timestamp_filter is None:
+            return # why is the peer sending this? should we disconnect?
+        self.recv_gossip_queue.put_nowait(('node_announcement', payload))

     def on_channel_announcement(self, payload):
-        if not self.lnworker.uses_trampoline():
-            self.gossip_queue.put_nowait(('channel_announcement', payload))
+        if self.lnworker.uses_trampoline():
+            return
+        if self.our_gossip_timestamp_filter is None:
+            return # why is the peer sending this? should we disconnect?
+        self.recv_gossip_queue.put_nowait(('channel_announcement', payload))

     def on_channel_update(self, payload):
         self.maybe_save_remote_update(payload)
-        if not self.lnworker.uses_trampoline():
-            self.gossip_queue.put_nowait(('channel_update', payload))
+        if self.lnworker.uses_trampoline():
+            return
+        if self.our_gossip_timestamp_filter is None:
+            return # why is the peer sending this? should we disconnect?
+        self.recv_gossip_queue.put_nowait(('channel_update', payload))

     def on_query_channel_range(self, payload):
         if self.lnworker == self.lnworker.network.lngossip or not self._should_forward_gossip():
@@ -419,7 +438,7 @@ class Peer(Logger, EventListener):
         if self.outgoing_gossip_reply:
             return self.send_warning(bytes(32), "received multiple queries at the same time")
         self.outgoing_gossip_reply = True
-        self.gossip_queue.put_nowait(('query_channel_range', payload))
+        self.recv_gossip_queue.put_nowait(('query_channel_range', payload))

     def on_query_short_channel_ids(self, payload):
         if self.lnworker == self.lnworker.network.lngossip or not self._should_forward_gossip():
@@ -429,7 +448,7 @@ class Peer(Logger, EventListener):
         if not self._is_valid_short_channel_id_query(payload):
             return self.send_warning(bytes(32), "invalid query_short_channel_ids")
         self.outgoing_gossip_reply = True
-        self.gossip_queue.put_nowait(('query_short_channel_ids', payload))
+        self.recv_gossip_queue.put_nowait(('query_short_channel_ids', payload))

     def on_gossip_timestamp_filter(self, payload):
         if self._should_forward_gossip():
@@ -441,11 +460,11 @@ class Peer(Logger, EventListener):
         if payload.get('chain_hash') != constants.net.rev_genesis_bytes():
             return
         filter = GossipTimestampFilter.from_payload(payload)
-        self.gossip_timestamp_filter = filter
+        self.their_gossip_timestamp_filter = filter
         self.logger.debug(f"got gossip_ts_filter from peer {self.pubkey.hex()}: "
-                          f"{str(self.gossip_timestamp_filter)}")
+                          f"{str(self.their_gossip_timestamp_filter)}")
         if filter and not filter.only_forwarding:
-            self.gossip_queue.put_nowait(('gossip_timestamp_filter', None))
+            self.recv_gossip_queue.put_nowait(('gossip_timestamp_filter', None))

     def maybe_save_remote_update(self, payload):
         if not self.channels:
@@ -521,7 +540,7 @@ class Peer(Logger, EventListener):
             chan_upds = []
             node_anns = []
             while True:
-                name, payload = await self.gossip_queue.get()
+                name, payload = await self.recv_gossip_queue.get()
                 if name == 'channel_announcement':
                     chan_anns.append(payload)
                 elif name == 'channel_update':
@@ -536,7 +555,7 @@ class Peer(Logger, EventListener):
                     await self.taskgroup.spawn(self._handle_historical_gossip_request())
                 else:
                     raise Exception('unknown message')
-                if self.gossip_queue.empty():
+                if self.recv_gossip_queue.empty():
                     break
             if self.network.lngossip:
                 await self.network.lngossip.process_gossip(chan_anns, node_anns, chan_upds)
@@ -577,7 +596,7 @@ class Peer(Logger, EventListener):
         last_gossip_batch_ts = 0
         while True:
             await asyncio.sleep(10)
-            if not self.gossip_timestamp_filter:
+            if not self.their_gossip_timestamp_filter:
                 continue # peer didn't request gossip

             new_gossip, last_lngossip_refresh_ts = await lngossip.get_forwarding_gossip()
@@ -589,7 +608,7 @@ class Peer(Logger, EventListener):

     async def _handle_historical_gossip_request(self):
         """Called when a peer requests historical gossip with a gossip_timestamp_filter query."""
-        filter = self.gossip_timestamp_filter
+        filter = self.their_gossip_timestamp_filter
         if not self._should_forward_gossip() or not filter or filter.only_forwarding:
             return
         async with self.network.lngossip.gossip_request_semaphore:
@@ -603,7 +622,7 @@ class Peer(Logger, EventListener):
     async def _send_gossip_messages(self, messages: List[GossipForwardingMessage]) -> int:
         amount_sent = 0
         for msg in messages:
-            if self.gossip_timestamp_filter.in_range(msg.timestamp) \
+            if self.their_gossip_timestamp_filter.in_range(msg.timestamp) \
                     and self.pubkey != msg.sender_node_id:
                 await self.transport.send_bytes_and_drain(msg.msg)
                 amount_sent += 1
@@ -697,6 +716,7 @@ class Peer(Logger, EventListener):
         self.outgoing_gossip_reply = False

     async def get_channel_range(self):
+        self.reply_channel_range = asyncio.Queue()
         first_block = constants.net.BLOCK_HEIGHT_FIRST_LIGHTNING_CHANNELS
         num_blocks = self.lnworker.network.get_local_height() - first_block
         self.query_channel_range(first_block, num_blocks)
@@ -735,6 +755,7 @@ class Peer(Logger, EventListener):
             a, b = intervals[0]
             if a <= first_block and b >= first_block + num_blocks:
                 break
+        self.reply_channel_range = None
         return ids, complete

     def request_gossip(self, timestamp=0):
@@ -742,11 +763,17 @@ class Peer(Logger, EventListener):
             self.logger.info('requesting whole channel graph')
         else:
             self.logger.info(f'requesting channel graph since {datetime.fromtimestamp(timestamp).isoformat()}')
+        timestamp_range = 0xFFFFFFFF
+        self.our_gossip_timestamp_filter = GossipTimestampFilter(
+            first_timestamp=timestamp,
+            timestamp_range=timestamp_range,
+        )
         self.send_message(
             'gossip_timestamp_filter',
             chain_hash=constants.net.rev_genesis_bytes(),
             first_timestamp=timestamp,
-            timestamp_range=b'\xff'*4)
+            timestamp_range=timestamp_range,
+        )

     def query_channel_range(self, first_block, num_blocks):
         self.logger.info(f'query channel range {first_block} {num_blocks}')
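
request_gossip now also remembers the filter it sends as self.our_gossip_timestamp_filter, which is what the on_node_announcement/on_channel_announcement/on_channel_update handlers above check before accepting unsolicited gossip. The GossipTimestampFilter class itself is not part of this diff; the stand-in below (hypothetical names, assuming the BOLT-7 semantics of first_timestamp/timestamp_range) only illustrates what in_range() is expected to do, with 0xFFFFFFFF meaning "everything from first_timestamp onwards":

    from dataclasses import dataclass

    @dataclass
    class TimestampFilter:  # hypothetical stand-in for electrum's GossipTimestampFilter
        first_timestamp: int
        timestamp_range: int  # 0xFFFFFFFF effectively means "everything from first_timestamp on"

        def in_range(self, timestamp: int) -> bool:
            # BOLT-7: relay gossip with first_timestamp <= timestamp < first_timestamp + timestamp_range
            return self.first_timestamp <= timestamp < self.first_timestamp + self.timestamp_range

    f = TimestampFilter(first_timestamp=1_700_000_000, timestamp_range=0xFFFFFFFF)
    assert f.in_range(1_800_000_000)
    assert not f.in_range(1_600_000_000)
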
@@ -766,7 +793,7 @@ class Peer(Logger, EventListener):
         ids = [decoded[i:i+8] for i in range(0, len(decoded), 8)]
         return ids

-    def on_reply_channel_range(self, payload):
+    async def on_reply_channel_range(self, payload):
         first = payload['first_blocknum']
         num = payload['number_of_blocks']
         complete = bool(int.from_bytes(payload['sync_complete'], 'big'))
@@ -774,6 +801,12 @@ class Peer(Logger, EventListener):
         ids = self.decode_short_ids(encoded)
         # self.logger.info(f"on_reply_channel_range. >>> first_block {first}, num_blocks {num}, "
         # f"num_ids {len(ids)}, complete {complete}")
+        if self.reply_channel_range is None:
+            raise Exception("received 'reply_channel_range' without corresponding 'query_channel_range'")
+        while self.reply_channel_range.qsize() > 10:
+            # we block process_message until the queue gets consumed
+            self.logger.info("reply_channel_range queue is overflowing. sleeping...")
+            await asyncio.sleep(0.1)
         self.reply_channel_range.put_nowait((first, num, complete, ids))

     async def _send_reply_short_channel_ids(self, payload: dict):
@@ -830,6 +863,15 @@ class Peer(Logger, EventListener):
             # rate-limit message-processing a bit, to make it harder
             # for a single peer to bog down the event loop / cpu:
             await asyncio.sleep(self.DELAY_INC_MSG_PROCESSING_SLEEP)
+            # If receiving too much gossip from this peer, we need to slow them down.
+            # note: if the gossip queue gets full, we will disconnect from them
+            # and throw away unprocessed gossip.
+            if self.recv_gossip_queue.qsize() > self.RECV_GOSSIP_QUEUE_SOFT_MAXSIZE:
+                sleep = self.recv_gossip_queue.qsize() / 1000
+                self.logger.debug(
+                    f"message_loop sleeping due to getting much gossip. qsize={self.recv_gossip_queue.qsize()}. "
+                    f"waiting for existing gossip data to be processed first.")
+                await asyncio.sleep(sleep)

     def on_reply_short_channel_ids_end(self, payload):
         self.querying.set()
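
The message_loop hunk above combines two limits: a per-message sleep (DELAY_INC_MSG_PROCESSING_SLEEP) and, for gossip, a soft threshold (RECV_GOSSIP_QUEUE_SOFT_MAXSIZE = 2000) above which the reader sleeps proportionally to the backlog, while the queue's hard maxsize (5000) makes put_nowait() raise so the flooding peer gets dropped. A self-contained sketch of the same soft/hard backpressure pattern, with made-up thresholds rather than electrum's:

    import asyncio

    SOFT_MAX = 20    # above this, the reader slows itself down
    HARD_MAX = 50    # queue maxsize: put_nowait() raises QueueFull beyond this

    async def reader(q: asyncio.Queue) -> None:
        for i in range(200):
            try:
                q.put_nowait(i)                        # never block the transport reader itself
            except asyncio.QueueFull:
                raise Exception("peer flooded us; disconnect")
            if q.qsize() > SOFT_MAX:
                await asyncio.sleep(q.qsize() / 1000)  # soft backpressure: let the consumer catch up
        q.put_nowait(None)                             # sentinel: end of stream

    async def consumer(q: asyncio.Queue) -> None:
        while (item := await q.get()) is not None:
            pass                                       # process the gossip item here

    async def main() -> None:
        q = asyncio.Queue(maxsize=HARD_MAX)
        await asyncio.gather(reader(q), consumer(q))

    asyncio.run(main())
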
@@ -1050,8 +1092,8 @@ class Peer(Logger, EventListener):
             int.from_bytes(per_commitment_secret_first, 'big'))

         # store the temp id now, so that it is recognized for e.g. 'error' messages
-        # TODO: this is never cleaned up; the dict grows unbounded until disconnect
         self.temp_id_to_id[temp_channel_id] = None
+        self._cleanup_temp_channelids()
         self.send_message(
             "open_channel",
             temporary_channel_id=temp_channel_id,
@@ -1270,8 +1312,8 @@ class Peer(Logger, EventListener):
         feerate = payload['feerate_per_kw'] # note: we are not validating this
         temp_chan_id = payload['temporary_channel_id']
         # store the temp id now, so that it is recognized for e.g. 'error' messages
-        # TODO: this is never cleaned up; the dict grows unbounded until disconnect
         self.temp_id_to_id[temp_chan_id] = None
+        self._cleanup_temp_channelids()
         channel_opening_fee = open_channel_tlvs.get('channel_opening_fee') if open_channel_tlvs else None
         if channel_opening_fee:
             # todo check that the fee is reasonable
@@ -1413,6 +1455,15 @@ class Peer(Logger, EventListener):
             self.send_channel_ready(chan)
         self.lnworker.add_new_channel(chan)

+    def _cleanup_temp_channelids(self) -> None:
+        self.temp_id_to_id = {
+            tmp_id: chan_id for (tmp_id, chan_id) in self.temp_id_to_id.items()
+            if chan_id not in self.channels
+        }
+        if len(self.temp_id_to_id) > 25:
+            # which one of us is opening all these chans?! let's disconnect
+            raise Exception("temp_id_to_id is getting too large.")
+
     async def request_force_close(self, channel_id: bytes):
         """Try to trigger the remote peer to force-close."""
         await self.initialized
electrum/lnworker.py
@@ -202,6 +202,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
         self.lock = threading.RLock()
         self.node_keypair = node_keypair
         self._peers = {} # type: Dict[bytes, Peer] # pubkey -> Peer # needs self.lock
+        self._channelless_incoming_peers = set() # type: Set[bytes] # node_ids # needs self.lock
         self.taskgroup = OldTaskGroup()
         self.listen_server = None # type: Optional[asyncio.AbstractServer]
         self.features = features
@@ -252,13 +253,15 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
             return
         addr = str(netaddr.host)

-        async def cb(reader, writer):
+        async def cb(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
             transport = LNResponderTransport(self.node_keypair.privkey, reader, writer)
             try:
                 node_id = await transport.handshake()
             except Exception as e:
                 self.logger.info(f'handshake failure from incoming connection: {e!r}')
                 return
+            peername = writer.get_extra_info('peername')
+            self.logger.debug(f"handshake done for incoming peer: {peername=}, node_id={node_id.hex()}")
             await self._add_peer_from_transport(node_id=node_id, transport=transport)
         try:
             self.listen_server = await asyncio.start_server(cb, addr, netaddr.port)
@@ -315,11 +318,29 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
             # both keep trying to reconnect, resulting in neither being usable.
             if existing_peer.is_initialized():
                 # give priority to the existing connection
-                return
+                transport.close()
+                return None
             else:
                 # Use the new connection. (e.g. old peer might be an outgoing connection
                 # for an outdated host/port that will never connect)
                 existing_peer.close_and_cleanup()
+        # limit max number of incoming channel-less peers.
+        # what to do if limit is reached?
+        #  - chosen strategy: we don't allow new connections.
+        #    - drawback: attacker can use up all our slots
+        #  - alternative: kick oldest channel-less peer
+        #    - drawback: if many legit peers want to connect to us, we will keep kicking them
+        #      in round-robin, and they will keep reconnecting. no stable state -> we self-DOS
+        # TODO make slots IP-based?
+        if isinstance(transport, LNResponderTransport):
+            assert node_id not in self._channelless_incoming_peers
+            chans = [chan for chan in self.channels_for_peer(node_id).values() if chan.is_funded()]
+            if not chans:
+                if len(self._channelless_incoming_peers) > 100:
+                    transport.close()
+                    return None
+                self._channelless_incoming_peers.add(node_id)
+        # checks done: we are adding this peer.
         peer = Peer(self, node_id, transport)
         assert node_id not in self._peers
         self._peers[node_id] = peer
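
The new _channelless_incoming_peers set implements the "don't allow new connections" strategy discussed in the comments above: incoming peers we have no funded channel with occupy a limited number of slots, and once the set is larger than 100 entries further incoming transports are simply closed. A toy sketch of that accept-or-refuse policy (hypothetical helper, much smaller limit):

    MAX_CHANNELLESS_INCOMING = 3   # hypothetical; the real check uses > 100

    channelless_incoming = set()   # node_ids of incoming peers without a funded channel

    def try_accept(node_id: bytes, has_funded_channel: bool) -> bool:
        """Return True if we keep the incoming connection, False if we refuse it."""
        if has_funded_channel:
            return True                    # peers with a funded channel always get a slot
        if len(channelless_incoming) > MAX_CHANNELLESS_INCOMING:
            return False                   # all channel-less slots are taken: refuse
        channelless_incoming.add(node_id)
        return True

    for i in range(6):
        print(i, try_accept(bytes([i]) * 33, has_funded_channel=False))
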
@@ -331,6 +352,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
             peer2 = self._peers.get(peer.pubkey)
             if peer2 is peer:
                 self._peers.pop(peer.pubkey)
+                self._channelless_incoming_peers.discard(peer.pubkey)

     def num_peers(self) -> int:
         return sum([p.is_initialized() for p in self.peers.values()])