1
0

onion_messages_manager:

- use namedtuple instead of dict for pending messages
 - use asyncio.Future instead of event and result
This commit is contained in:
ThomasV
2025-01-18 14:01:02 +01:00
committed by Sander van Grieken
parent d814796484
commit 71b9761981

View File

@@ -28,7 +28,7 @@ import threading
import time
from random import random
from typing import TYPE_CHECKING, Optional, List, Sequence
from typing import TYPE_CHECKING, Optional, List, Sequence, NamedTuple
import electrum_ecc as ecc
@@ -40,7 +40,7 @@ from electrum.lnonion import (get_bolt04_onion_key, OnionPacket, process_onion_p
OnionHopsDataSingle, decrypt_onionmsg_data_tlv, encrypt_onionmsg_data_tlv,
get_shared_secrets_along_route, new_onion_packet)
from electrum.lnutil import LnFeatures
from electrum.util import OldTaskGroup
from electrum.util import OldTaskGroup, log_exceptions
# do not use util.now, because it rounds to integers
def now():
@@ -380,6 +380,11 @@ def get_blinded_reply_paths(lnwallet: 'LNWallet', path_id: bytes, *,
class Timeout(Exception): pass
class OnionMessageRequest(NamedTuple):
future: asyncio.Future
payload: bytes
node_id_or_blinded_path: bytes
class OnionMessageManager(Logger):
"""handle state around onion message sends and receives
@@ -398,7 +403,6 @@ class OnionMessageManager(Logger):
self.network = None # type: Optional['Network']
self.taskgroup = None # type: OldTaskGroup
self.lnwallet = lnwallet
self.pending = {}
self.pending_lock = threading.Lock()
self.send_queue = asyncio.PriorityQueue()
@@ -411,16 +415,13 @@ class OnionMessageManager(Logger):
self.taskgroup = OldTaskGroup()
asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop)
@log_exceptions
async def main_loop(self):
self.logger.info("starting taskgroup.")
try:
async with self.taskgroup as group:
await group.spawn(self.process_send_queue())
await group.spawn(self.process_forward_queue())
except Exception as e:
self.logger.exception("taskgroup died.")
else:
self.logger.info("taskgroup stopped.")
async with self.taskgroup as group:
await group.spawn(self.process_send_queue())
await group.spawn(self.process_forward_queue())
self.logger.info("taskgroup stopped.")
async def stop(self):
await self.taskgroup.cancel_remaining()
@@ -466,16 +467,16 @@ class OnionMessageManager(Logger):
async def process_send_queue(self):
while True:
scheduled, expires, key = await self.send_queue.get()
requestreply = self.get_pending_message(key)
if requestreply is None:
req = self.pending.get(key)
if req is None:
self.logger.debug(f'no data for key {key=}')
continue
if requestreply.get('result') is not None:
if req.future.done():
self.logger.debug(f'has result! {key=}')
continue
if expires <= now():
self.logger.debug(f'expired {key=}')
self._set_message_result(key, Timeout())
req.future.set_exception(Timeout())
continue
if scheduled > now():
# return to queue
@@ -483,12 +484,11 @@ class OnionMessageManager(Logger):
self.send_queue.put_nowait((scheduled, expires, key))
await asyncio.sleep(self.SLEEP_DELAY) # sleep here, as the first queue item wasn't due yet
continue
try:
self._send_pending_message(key)
except BaseException as e:
self.logger.debug(f'error while sending {key=} {e!r}')
self._set_message_result(key, copy.copy(e))
req.future.set_exception(copy.copy(e))
# NOTE: above, when passing the caught exception instance e directly it leads to GeneratorExit() in
if isinstance(e, NoRouteFound) and e.peer_address:
await self.lnwallet.add_peer(str(e.peer_address))
@@ -496,26 +496,10 @@ class OnionMessageManager(Logger):
self.logger.debug(f'resubmit {key=}')
self.send_queue.put_nowait((now() + self.REQUEST_REPLY_RETRY_DELAY, expires, key))
def get_pending_message(self, key):
with self.pending_lock:
return self.pending.get(key)
def _set_message_result(self, key, result):
with self.pending_lock:
requestreply = self.pending.get(key)
if requestreply is None:
self.logger.error(f'requestreply with {key=} not found!')
return
self.pending[key]['result'] = result
requestreply['ev'].set()
def _remove_pending_message(self, key):
with self.pending_lock:
requestreply = self.pending.get(key)
if requestreply is None:
return
requestreply['ev'].set()
del self.pending[key]
if key in self.pending:
del self.pending[key]
def submit_send(
self, *,
@@ -535,48 +519,34 @@ class OnionMessageManager(Logger):
self.logger.debug(f'submit_send {key=} {payload=} {node_id_or_blinded_path=}')
req = OnionMessageRequest(
future=asyncio.Future(),
payload=payload,
node_id_or_blinded_path=node_id_or_blinded_path
)
with self.pending_lock:
if key in self.pending:
raise Exception(f'{key=} already exists!')
self.pending[key] = {
'ev': asyncio.Event(),
'payload': payload,
'node_id_or_blinded_path': node_id_or_blinded_path
}
self.pending[key] = req
# tuple = (when to process, when it expires, key)
expires = now() + self.REQUEST_REPLY_TIMEOUT
queueitem = (now(), expires, key)
self.send_queue.put_nowait(queueitem)
task = asyncio.create_task(self._wait_task(key))
task = asyncio.create_task(self._wait_task(key, req.future))
return task
async def _wait_task(self, key):
requestreply = self.get_pending_message(key)
assert requestreply
if requestreply is None:
return
async def _wait_task(self, key, future):
try:
self.logger.debug(f'wait task start {key}')
await requestreply['ev'].wait()
finally:
self.logger.debug(f'wait task end {key}')
try:
requestreply = self.get_pending_message(key)
assert requestreply
result = requestreply.get('result')
if isinstance(result, Exception):
raise result # raising in the task requires caller to explicitly extract exception.
return result
return await future
finally:
self._remove_pending_message(key)
def _send_pending_message(self, key):
"""adds reply_path to payload"""
data = self.get_pending_message(key)
payload = data.get('payload')
node_id_or_blinded_path = data.get('node_id_or_blinded_path')
req = self.pending.get(key)
payload = req.payload
node_id_or_blinded_path = req.node_id_or_blinded_path
self.logger.debug(f'send_pending_message {key=} {payload=} {node_id_or_blinded_path=}')
final_payload = copy.deepcopy(payload)
@@ -624,12 +594,11 @@ class OnionMessageManager(Logger):
self.logger.warning('not a reply to our request (unknown path_id prefix)')
return
key = correl_data[8:]
requestreply = self.get_pending_message(key)
if requestreply is None:
req = self.pending.get(key)
if req is None:
self.logger.warning('not a reply to our request (unknown request)')
return
self._set_message_result(key, (recipient_data, payload))
req.future.set_result((recipient_data, payload))
def on_onion_message_received_unsolicited(self, recipient_data, payload):
self.logger.debug('unsolicited onion_message received')