onion_messages_manager:
- use namedtuple instead of dict for pending messages - use asyncio.Future instead of event and result
This commit is contained in:
committed by
Sander van Grieken
parent
d814796484
commit
71b9761981
@@ -28,7 +28,7 @@ import threading
|
||||
import time
|
||||
from random import random
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, List, Sequence
|
||||
from typing import TYPE_CHECKING, Optional, List, Sequence, NamedTuple
|
||||
|
||||
import electrum_ecc as ecc
|
||||
|
||||
@@ -40,7 +40,7 @@ from electrum.lnonion import (get_bolt04_onion_key, OnionPacket, process_onion_p
|
||||
OnionHopsDataSingle, decrypt_onionmsg_data_tlv, encrypt_onionmsg_data_tlv,
|
||||
get_shared_secrets_along_route, new_onion_packet)
|
||||
from electrum.lnutil import LnFeatures
|
||||
from electrum.util import OldTaskGroup
|
||||
from electrum.util import OldTaskGroup, log_exceptions
|
||||
|
||||
# do not use util.now, because it rounds to integers
|
||||
def now():
|
||||
@@ -380,6 +380,11 @@ def get_blinded_reply_paths(lnwallet: 'LNWallet', path_id: bytes, *,
|
||||
|
||||
class Timeout(Exception): pass
|
||||
|
||||
class OnionMessageRequest(NamedTuple):
|
||||
future: asyncio.Future
|
||||
payload: bytes
|
||||
node_id_or_blinded_path: bytes
|
||||
|
||||
|
||||
class OnionMessageManager(Logger):
|
||||
"""handle state around onion message sends and receives
|
||||
@@ -398,7 +403,6 @@ class OnionMessageManager(Logger):
|
||||
self.network = None # type: Optional['Network']
|
||||
self.taskgroup = None # type: OldTaskGroup
|
||||
self.lnwallet = lnwallet
|
||||
|
||||
self.pending = {}
|
||||
self.pending_lock = threading.Lock()
|
||||
self.send_queue = asyncio.PriorityQueue()
|
||||
@@ -411,16 +415,13 @@ class OnionMessageManager(Logger):
|
||||
self.taskgroup = OldTaskGroup()
|
||||
asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop)
|
||||
|
||||
@log_exceptions
|
||||
async def main_loop(self):
|
||||
self.logger.info("starting taskgroup.")
|
||||
try:
|
||||
async with self.taskgroup as group:
|
||||
await group.spawn(self.process_send_queue())
|
||||
await group.spawn(self.process_forward_queue())
|
||||
except Exception as e:
|
||||
self.logger.exception("taskgroup died.")
|
||||
else:
|
||||
self.logger.info("taskgroup stopped.")
|
||||
async with self.taskgroup as group:
|
||||
await group.spawn(self.process_send_queue())
|
||||
await group.spawn(self.process_forward_queue())
|
||||
self.logger.info("taskgroup stopped.")
|
||||
|
||||
async def stop(self):
|
||||
await self.taskgroup.cancel_remaining()
|
||||
@@ -466,16 +467,16 @@ class OnionMessageManager(Logger):
|
||||
async def process_send_queue(self):
|
||||
while True:
|
||||
scheduled, expires, key = await self.send_queue.get()
|
||||
requestreply = self.get_pending_message(key)
|
||||
if requestreply is None:
|
||||
req = self.pending.get(key)
|
||||
if req is None:
|
||||
self.logger.debug(f'no data for key {key=}')
|
||||
continue
|
||||
if requestreply.get('result') is not None:
|
||||
if req.future.done():
|
||||
self.logger.debug(f'has result! {key=}')
|
||||
continue
|
||||
if expires <= now():
|
||||
self.logger.debug(f'expired {key=}')
|
||||
self._set_message_result(key, Timeout())
|
||||
req.future.set_exception(Timeout())
|
||||
continue
|
||||
if scheduled > now():
|
||||
# return to queue
|
||||
@@ -483,12 +484,11 @@ class OnionMessageManager(Logger):
|
||||
self.send_queue.put_nowait((scheduled, expires, key))
|
||||
await asyncio.sleep(self.SLEEP_DELAY) # sleep here, as the first queue item wasn't due yet
|
||||
continue
|
||||
|
||||
try:
|
||||
self._send_pending_message(key)
|
||||
except BaseException as e:
|
||||
self.logger.debug(f'error while sending {key=} {e!r}')
|
||||
self._set_message_result(key, copy.copy(e))
|
||||
req.future.set_exception(copy.copy(e))
|
||||
# NOTE: above, when passing the caught exception instance e directly it leads to GeneratorExit() in
|
||||
if isinstance(e, NoRouteFound) and e.peer_address:
|
||||
await self.lnwallet.add_peer(str(e.peer_address))
|
||||
@@ -496,26 +496,10 @@ class OnionMessageManager(Logger):
|
||||
self.logger.debug(f'resubmit {key=}')
|
||||
self.send_queue.put_nowait((now() + self.REQUEST_REPLY_RETRY_DELAY, expires, key))
|
||||
|
||||
def get_pending_message(self, key):
|
||||
with self.pending_lock:
|
||||
return self.pending.get(key)
|
||||
|
||||
def _set_message_result(self, key, result):
|
||||
with self.pending_lock:
|
||||
requestreply = self.pending.get(key)
|
||||
if requestreply is None:
|
||||
self.logger.error(f'requestreply with {key=} not found!')
|
||||
return
|
||||
self.pending[key]['result'] = result
|
||||
requestreply['ev'].set()
|
||||
|
||||
def _remove_pending_message(self, key):
|
||||
with self.pending_lock:
|
||||
requestreply = self.pending.get(key)
|
||||
if requestreply is None:
|
||||
return
|
||||
requestreply['ev'].set()
|
||||
del self.pending[key]
|
||||
if key in self.pending:
|
||||
del self.pending[key]
|
||||
|
||||
def submit_send(
|
||||
self, *,
|
||||
@@ -535,48 +519,34 @@ class OnionMessageManager(Logger):
|
||||
|
||||
self.logger.debug(f'submit_send {key=} {payload=} {node_id_or_blinded_path=}')
|
||||
|
||||
req = OnionMessageRequest(
|
||||
future=asyncio.Future(),
|
||||
payload=payload,
|
||||
node_id_or_blinded_path=node_id_or_blinded_path
|
||||
)
|
||||
with self.pending_lock:
|
||||
if key in self.pending:
|
||||
raise Exception(f'{key=} already exists!')
|
||||
self.pending[key] = {
|
||||
'ev': asyncio.Event(),
|
||||
'payload': payload,
|
||||
'node_id_or_blinded_path': node_id_or_blinded_path
|
||||
}
|
||||
self.pending[key] = req
|
||||
|
||||
# tuple = (when to process, when it expires, key)
|
||||
expires = now() + self.REQUEST_REPLY_TIMEOUT
|
||||
queueitem = (now(), expires, key)
|
||||
self.send_queue.put_nowait(queueitem)
|
||||
task = asyncio.create_task(self._wait_task(key))
|
||||
task = asyncio.create_task(self._wait_task(key, req.future))
|
||||
return task
|
||||
|
||||
async def _wait_task(self, key):
|
||||
requestreply = self.get_pending_message(key)
|
||||
assert requestreply
|
||||
if requestreply is None:
|
||||
return
|
||||
async def _wait_task(self, key, future):
|
||||
try:
|
||||
self.logger.debug(f'wait task start {key}')
|
||||
await requestreply['ev'].wait()
|
||||
finally:
|
||||
self.logger.debug(f'wait task end {key}')
|
||||
|
||||
try:
|
||||
requestreply = self.get_pending_message(key)
|
||||
assert requestreply
|
||||
result = requestreply.get('result')
|
||||
if isinstance(result, Exception):
|
||||
raise result # raising in the task requires caller to explicitly extract exception.
|
||||
return result
|
||||
return await future
|
||||
finally:
|
||||
self._remove_pending_message(key)
|
||||
|
||||
def _send_pending_message(self, key):
|
||||
"""adds reply_path to payload"""
|
||||
data = self.get_pending_message(key)
|
||||
payload = data.get('payload')
|
||||
node_id_or_blinded_path = data.get('node_id_or_blinded_path')
|
||||
req = self.pending.get(key)
|
||||
payload = req.payload
|
||||
node_id_or_blinded_path = req.node_id_or_blinded_path
|
||||
self.logger.debug(f'send_pending_message {key=} {payload=} {node_id_or_blinded_path=}')
|
||||
|
||||
final_payload = copy.deepcopy(payload)
|
||||
@@ -624,12 +594,11 @@ class OnionMessageManager(Logger):
|
||||
self.logger.warning('not a reply to our request (unknown path_id prefix)')
|
||||
return
|
||||
key = correl_data[8:]
|
||||
requestreply = self.get_pending_message(key)
|
||||
if requestreply is None:
|
||||
req = self.pending.get(key)
|
||||
if req is None:
|
||||
self.logger.warning('not a reply to our request (unknown request)')
|
||||
return
|
||||
|
||||
self._set_message_result(key, (recipient_data, payload))
|
||||
req.future.set_result((recipient_data, payload))
|
||||
|
||||
def on_onion_message_received_unsolicited(self, recipient_data, payload):
|
||||
self.logger.debug('unsolicited onion_message received')
|
||||
|
||||
Reference in New Issue
Block a user