lnpeer: factorize on_warning/on_error code
This commit is contained in:
@@ -224,44 +224,35 @@ class Peer(Logger):
|
||||
if asyncio.iscoroutinefunction(f):
|
||||
asyncio.ensure_future(self.taskgroup.spawn(execution_result))
|
||||
|
||||
def _get_channel_ids(self, channel_id):
|
||||
# if channel_id is all zero: MUST fail all channels with the sending node.
|
||||
# otherwise: MUST fail the channel referred to by channel_id, if that channel is with the sending node.
|
||||
# if no existing channel is referred to by `channel_id: MUST ignore the message.
|
||||
if channel_id == bytes(32):
|
||||
return self.channels.keys()
|
||||
elif channel_id in self.temp_id_to_id:
|
||||
return [self.temp_id_to_id[channel_id]]
|
||||
elif channel_id in self.channels:
|
||||
return [channel_id]
|
||||
else:
|
||||
return []
|
||||
|
||||
def on_warning(self, payload):
|
||||
# TODO: we could need some reconnection logic here -> delayed reconnect
|
||||
self.logger.info(f"remote peer sent warning [DO NOT TRUST THIS MESSAGE]: {payload['data'].decode('ascii')}")
|
||||
channel_id = payload.get("channel_id")
|
||||
if channel_id == bytes(32):
|
||||
for cid in self.channels.keys():
|
||||
self.ordered_message_queues[cid].put_nowait((None, {'warning': payload['data']}))
|
||||
raise GracefulDisconnect
|
||||
warned_channel_id = None
|
||||
if channel_id in self.temp_id_to_id:
|
||||
warned_channel_id = self.temp_id_to_id[channel_id]
|
||||
elif channel_id in self.channels:
|
||||
warned_channel_id = channel_id
|
||||
if warned_channel_id:
|
||||
# MAY disconnect.
|
||||
self.ordered_message_queues[warned_channel_id].put_nowait((None, {'warning': payload['data']}))
|
||||
channel_ids = self._get_channel_ids(payload.get("channel_id"))
|
||||
for cid in channel_ids:
|
||||
self.ordered_message_queues[cid].put_nowait((None, {'warning': payload['data']}))
|
||||
if channel_ids:
|
||||
raise GracefulDisconnect
|
||||
|
||||
def on_error(self, payload):
|
||||
self.logger.info(f"remote peer sent error [DO NOT TRUST THIS MESSAGE]: {payload['data'].decode('ascii')}")
|
||||
channel_id = payload.get("channel_id")
|
||||
# if channel_id is all zero: MUST fail all channels with the sending node.
|
||||
if channel_id == bytes(32):
|
||||
for cid in self.channels.keys():
|
||||
self.schedule_force_closing(cid)
|
||||
self.ordered_message_queues[cid].put_nowait((None, {'error': payload['data']}))
|
||||
raise GracefulDisconnect
|
||||
# otherwise: MUST fail the channel referred to by channel_id, if that channel is with the sending node.
|
||||
erring_channel_id = None
|
||||
if channel_id in self.temp_id_to_id:
|
||||
erring_channel_id = self.temp_id_to_id[channel_id]
|
||||
elif channel_id in self.channels:
|
||||
erring_channel_id = channel_id
|
||||
if erring_channel_id:
|
||||
self.schedule_force_closing(erring_channel_id)
|
||||
self.ordered_message_queues[erring_channel_id].put_nowait((None, {'error': payload['data']}))
|
||||
# disconnect now as there might be no one waiting on the queue...
|
||||
# OTOH this means if there are waiters, they might not see the error
|
||||
channel_ids = self._get_channel_ids(payload.get("channel_id"))
|
||||
for cid in channel_ids:
|
||||
self.schedule_force_closing(cid)
|
||||
self.ordered_message_queues[cid].put_nowait((None, {'error': payload['data']}))
|
||||
if channel_ids:
|
||||
raise GracefulDisconnect
|
||||
|
||||
async def send_warning(self, channel_id: bytes, message: str = None, *, close_connection=True):
|
||||
|
||||
Reference in New Issue
Block a user