From 71b97619812ee510a5353112fad93e00b98495e2 Mon Sep 17 00:00:00 2001 From: ThomasV Date: Sat, 18 Jan 2025 14:01:02 +0100 Subject: [PATCH] onion_messages_manager: - use namedtuple instead of dict for pending messages - use asyncio.Future instead of event and result --- electrum/onion_message.py | 99 ++++++++++++++------------------------- 1 file changed, 34 insertions(+), 65 deletions(-) diff --git a/electrum/onion_message.py b/electrum/onion_message.py index 64f89927f..c9ab88a53 100644 --- a/electrum/onion_message.py +++ b/electrum/onion_message.py @@ -28,7 +28,7 @@ import threading import time from random import random -from typing import TYPE_CHECKING, Optional, List, Sequence +from typing import TYPE_CHECKING, Optional, List, Sequence, NamedTuple import electrum_ecc as ecc @@ -40,7 +40,7 @@ from electrum.lnonion import (get_bolt04_onion_key, OnionPacket, process_onion_p OnionHopsDataSingle, decrypt_onionmsg_data_tlv, encrypt_onionmsg_data_tlv, get_shared_secrets_along_route, new_onion_packet) from electrum.lnutil import LnFeatures -from electrum.util import OldTaskGroup +from electrum.util import OldTaskGroup, log_exceptions # do not use util.now, because it rounds to integers def now(): @@ -380,6 +380,11 @@ def get_blinded_reply_paths(lnwallet: 'LNWallet', path_id: bytes, *, class Timeout(Exception): pass +class OnionMessageRequest(NamedTuple): + future: asyncio.Future + payload: bytes + node_id_or_blinded_path: bytes + class OnionMessageManager(Logger): """handle state around onion message sends and receives @@ -398,7 +403,6 @@ class OnionMessageManager(Logger): self.network = None # type: Optional['Network'] self.taskgroup = None # type: OldTaskGroup self.lnwallet = lnwallet - self.pending = {} self.pending_lock = threading.Lock() self.send_queue = asyncio.PriorityQueue() @@ -411,16 +415,13 @@ class OnionMessageManager(Logger): self.taskgroup = OldTaskGroup() asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop) + @log_exceptions async def main_loop(self): self.logger.info("starting taskgroup.") - try: - async with self.taskgroup as group: - await group.spawn(self.process_send_queue()) - await group.spawn(self.process_forward_queue()) - except Exception as e: - self.logger.exception("taskgroup died.") - else: - self.logger.info("taskgroup stopped.") + async with self.taskgroup as group: + await group.spawn(self.process_send_queue()) + await group.spawn(self.process_forward_queue()) + self.logger.info("taskgroup stopped.") async def stop(self): await self.taskgroup.cancel_remaining() @@ -466,16 +467,16 @@ class OnionMessageManager(Logger): async def process_send_queue(self): while True: scheduled, expires, key = await self.send_queue.get() - requestreply = self.get_pending_message(key) - if requestreply is None: + req = self.pending.get(key) + if req is None: self.logger.debug(f'no data for key {key=}') continue - if requestreply.get('result') is not None: + if req.future.done(): self.logger.debug(f'has result! {key=}') continue if expires <= now(): self.logger.debug(f'expired {key=}') - self._set_message_result(key, Timeout()) + req.future.set_exception(Timeout()) continue if scheduled > now(): # return to queue @@ -483,12 +484,11 @@ class OnionMessageManager(Logger): self.send_queue.put_nowait((scheduled, expires, key)) await asyncio.sleep(self.SLEEP_DELAY) # sleep here, as the first queue item wasn't due yet continue - try: self._send_pending_message(key) except BaseException as e: self.logger.debug(f'error while sending {key=} {e!r}') - self._set_message_result(key, copy.copy(e)) + req.future.set_exception(copy.copy(e)) # NOTE: above, when passing the caught exception instance e directly it leads to GeneratorExit() in if isinstance(e, NoRouteFound) and e.peer_address: await self.lnwallet.add_peer(str(e.peer_address)) @@ -496,26 +496,10 @@ class OnionMessageManager(Logger): self.logger.debug(f'resubmit {key=}') self.send_queue.put_nowait((now() + self.REQUEST_REPLY_RETRY_DELAY, expires, key)) - def get_pending_message(self, key): - with self.pending_lock: - return self.pending.get(key) - - def _set_message_result(self, key, result): - with self.pending_lock: - requestreply = self.pending.get(key) - if requestreply is None: - self.logger.error(f'requestreply with {key=} not found!') - return - self.pending[key]['result'] = result - requestreply['ev'].set() - def _remove_pending_message(self, key): with self.pending_lock: - requestreply = self.pending.get(key) - if requestreply is None: - return - requestreply['ev'].set() - del self.pending[key] + if key in self.pending: + del self.pending[key] def submit_send( self, *, @@ -535,48 +519,34 @@ class OnionMessageManager(Logger): self.logger.debug(f'submit_send {key=} {payload=} {node_id_or_blinded_path=}') + req = OnionMessageRequest( + future=asyncio.Future(), + payload=payload, + node_id_or_blinded_path=node_id_or_blinded_path + ) with self.pending_lock: if key in self.pending: raise Exception(f'{key=} already exists!') - self.pending[key] = { - 'ev': asyncio.Event(), - 'payload': payload, - 'node_id_or_blinded_path': node_id_or_blinded_path - } + self.pending[key] = req # tuple = (when to process, when it expires, key) expires = now() + self.REQUEST_REPLY_TIMEOUT queueitem = (now(), expires, key) self.send_queue.put_nowait(queueitem) - task = asyncio.create_task(self._wait_task(key)) + task = asyncio.create_task(self._wait_task(key, req.future)) return task - async def _wait_task(self, key): - requestreply = self.get_pending_message(key) - assert requestreply - if requestreply is None: - return + async def _wait_task(self, key, future): try: - self.logger.debug(f'wait task start {key}') - await requestreply['ev'].wait() - finally: - self.logger.debug(f'wait task end {key}') - - try: - requestreply = self.get_pending_message(key) - assert requestreply - result = requestreply.get('result') - if isinstance(result, Exception): - raise result # raising in the task requires caller to explicitly extract exception. - return result + return await future finally: self._remove_pending_message(key) def _send_pending_message(self, key): """adds reply_path to payload""" - data = self.get_pending_message(key) - payload = data.get('payload') - node_id_or_blinded_path = data.get('node_id_or_blinded_path') + req = self.pending.get(key) + payload = req.payload + node_id_or_blinded_path = req.node_id_or_blinded_path self.logger.debug(f'send_pending_message {key=} {payload=} {node_id_or_blinded_path=}') final_payload = copy.deepcopy(payload) @@ -624,12 +594,11 @@ class OnionMessageManager(Logger): self.logger.warning('not a reply to our request (unknown path_id prefix)') return key = correl_data[8:] - requestreply = self.get_pending_message(key) - if requestreply is None: + req = self.pending.get(key) + if req is None: self.logger.warning('not a reply to our request (unknown request)') return - - self._set_message_result(key, (recipient_data, payload)) + req.future.set_result((recipient_data, payload)) def on_onion_message_received_unsolicited(self, recipient_data, payload): self.logger.debug('unsolicited onion_message received')