bolts: do not disconnect when receiving/sending "warning" messages
follow https://github.com/lightning/bolts/pull/1075
This commit is contained in:
@@ -102,7 +102,7 @@ class Peer(Logger):
|
|||||||
self.reply_channel_range = asyncio.Queue()
|
self.reply_channel_range = asyncio.Queue()
|
||||||
# gossip uses a single queue to preserve message order
|
# gossip uses a single queue to preserve message order
|
||||||
self.gossip_queue = asyncio.Queue()
|
self.gossip_queue = asyncio.Queue()
|
||||||
self.ordered_message_queues = defaultdict(asyncio.Queue) # for messages that are ordered
|
self.ordered_message_queues = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue] # for messages that are ordered
|
||||||
self.temp_id_to_id = {} # type: Dict[bytes, Optional[bytes]] # to forward error messages
|
self.temp_id_to_id = {} # type: Dict[bytes, Optional[bytes]] # to forward error messages
|
||||||
self.funding_created_sent = set() # for channels in PREOPENING
|
self.funding_created_sent = set() # for channels in PREOPENING
|
||||||
self.funding_signed_sent = set() # for channels in PREOPENING
|
self.funding_signed_sent = set() # for channels in PREOPENING
|
||||||
@@ -242,24 +242,18 @@ class Peer(Logger):
|
|||||||
def on_warning(self, payload):
|
def on_warning(self, payload):
|
||||||
chan_id = payload.get("channel_id")
|
chan_id = payload.get("channel_id")
|
||||||
err_bytes = payload['data']
|
err_bytes = payload['data']
|
||||||
|
is_known_chan_id = (chan_id in self.channels) or (chan_id in self.temp_id_to_id)
|
||||||
self.logger.info(f"remote peer sent warning [DO NOT TRUST THIS MESSAGE]: "
|
self.logger.info(f"remote peer sent warning [DO NOT TRUST THIS MESSAGE]: "
|
||||||
f"{error_text_bytes_to_safe_str(err_bytes)}. chan_id={chan_id.hex()}")
|
f"{error_text_bytes_to_safe_str(err_bytes)}. chan_id={chan_id.hex()}. "
|
||||||
if chan_id in self.channels:
|
f"{is_known_chan_id=}")
|
||||||
self.ordered_message_queues[chan_id].put_nowait((None, {'warning': err_bytes}))
|
|
||||||
elif chan_id in self.temp_id_to_id:
|
|
||||||
chan_id = self.temp_id_to_id[chan_id] or chan_id
|
|
||||||
self.ordered_message_queues[chan_id].put_nowait((None, {'warning': err_bytes}))
|
|
||||||
else:
|
|
||||||
# if no existing channel is referred to by channel_id:
|
|
||||||
# - MUST ignore the message.
|
|
||||||
return
|
|
||||||
raise GracefulDisconnect
|
|
||||||
|
|
||||||
def on_error(self, payload):
|
def on_error(self, payload):
|
||||||
chan_id = payload.get("channel_id")
|
chan_id = payload.get("channel_id")
|
||||||
err_bytes = payload['data']
|
err_bytes = payload['data']
|
||||||
|
is_known_chan_id = (chan_id in self.channels) or (chan_id in self.temp_id_to_id)
|
||||||
self.logger.info(f"remote peer sent error [DO NOT TRUST THIS MESSAGE]: "
|
self.logger.info(f"remote peer sent error [DO NOT TRUST THIS MESSAGE]: "
|
||||||
f"{error_text_bytes_to_safe_str(err_bytes)}. chan_id={chan_id.hex()}")
|
f"{error_text_bytes_to_safe_str(err_bytes)}. chan_id={chan_id.hex()}. "
|
||||||
|
f"{is_known_chan_id=}")
|
||||||
if chan_id in self.channels:
|
if chan_id in self.channels:
|
||||||
self.schedule_force_closing(chan_id)
|
self.schedule_force_closing(chan_id)
|
||||||
self.ordered_message_queues[chan_id].put_nowait((None, {'error': err_bytes}))
|
self.ordered_message_queues[chan_id].put_nowait((None, {'error': err_bytes}))
|
||||||
@@ -278,7 +272,7 @@ class Peer(Logger):
|
|||||||
return
|
return
|
||||||
raise GracefulDisconnect
|
raise GracefulDisconnect
|
||||||
|
|
||||||
async def send_warning(self, channel_id: bytes, message: str = None, *, close_connection=True):
|
async def send_warning(self, channel_id: bytes, message: str = None, *, close_connection=False):
|
||||||
"""Sends a warning and disconnects if close_connection.
|
"""Sends a warning and disconnects if close_connection.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
@@ -335,15 +329,14 @@ class Peer(Logger):
|
|||||||
def on_pong(self, payload):
|
def on_pong(self, payload):
|
||||||
self.pong_event.set()
|
self.pong_event.set()
|
||||||
|
|
||||||
async def wait_for_message(self, expected_name, channel_id):
|
async def wait_for_message(self, expected_name: str, channel_id: bytes):
|
||||||
q = self.ordered_message_queues[channel_id]
|
q = self.ordered_message_queues[channel_id]
|
||||||
name, payload = await asyncio.wait_for(q.get(), LN_P2P_NETWORK_TIMEOUT)
|
name, payload = await asyncio.wait_for(q.get(), LN_P2P_NETWORK_TIMEOUT)
|
||||||
# raise exceptions for errors/warnings, so that the caller sees them
|
# raise exceptions for errors, so that the caller sees them
|
||||||
if (err_bytes := (payload.get("error") or payload.get("warning"))) is not None:
|
if (err_bytes := payload.get("error")) is not None:
|
||||||
err_type = "error" if payload.get("error") else "warning"
|
|
||||||
err_text = error_text_bytes_to_safe_str(err_bytes)
|
err_text = error_text_bytes_to_safe_str(err_bytes)
|
||||||
raise GracefulDisconnect(
|
raise GracefulDisconnect(
|
||||||
f"remote peer sent {err_type} [DO NOT TRUST THIS MESSAGE]: {err_text}")
|
f"remote peer sent error [DO NOT TRUST THIS MESSAGE]: {err_text}")
|
||||||
if name != expected_name:
|
if name != expected_name:
|
||||||
raise Exception(f"Received unexpected '{name}'")
|
raise Exception(f"Received unexpected '{name}'")
|
||||||
return payload
|
return payload
|
||||||
|
|||||||
Reference in New Issue
Block a user