From e216f1b324df8ee5fa792ac20b7456a54294400b Mon Sep 17 00:00:00 2001 From: Sander van Grieken Date: Wed, 12 Feb 2025 14:00:18 +0100 Subject: [PATCH] onion_messages: add parameter typing and comments --- electrum/onion_message.py | 105 +++++++++++++++++++++++++++----------- 1 file changed, 75 insertions(+), 30 deletions(-) diff --git a/electrum/onion_message.py b/electrum/onion_message.py index c9ab88a53..6e1aa322c 100644 --- a/electrum/onion_message.py +++ b/electrum/onion_message.py @@ -42,6 +42,7 @@ from electrum.lnonion import (get_bolt04_onion_key, OnionPacket, process_onion_p from electrum.lnutil import LnFeatures from electrum.util import OldTaskGroup, log_exceptions + # do not use util.now, because it rounds to integers def now(): return time.time() @@ -66,9 +67,14 @@ class NoRouteFound(Exception): self.peer_address = peer_address -def create_blinded_path(session_key: bytes, path: List[bytes], final_recipient_data: dict, *, - hop_extras: Optional[Sequence[dict]] = None, - dummy_hops: Optional[int] = 0) -> dict: +def create_blinded_path( + session_key: bytes, + path: List[bytes], + final_recipient_data: dict, + *, + hop_extras: Optional[Sequence[dict]] = None, + dummy_hops: Optional[int] = 0 +) -> dict: # dummy hops could be inserted anywhere in the path, but for compatibility just add them at the end # because blinded paths are usually constructed towards ourselves, and we know we can handle dummy hops. if dummy_hops: @@ -114,7 +120,7 @@ def create_blinded_path(session_key: bytes, path: List[bytes], final_recipient_d return blinded_path -def blinding_privkey(privkey, blinding): +def blinding_privkey(privkey: bytes, blinding: bytes) -> bytes: shared_secret = get_ecdh(privkey, blinding) b_hmac = get_bolt04_onion_key(b'blinded_node_id', shared_secret) b_hmac_int = int.from_bytes(b_hmac, byteorder="big") @@ -126,13 +132,13 @@ def blinding_privkey(privkey, blinding): return our_privkey -def is_onion_message_node(node_id: bytes, node_info: Optional['NodeInfo']): +def is_onion_message_node(node_id: bytes, node_info: Optional['NodeInfo']) -> bool: if not node_info: return False return LnFeatures(node_info.features).supports(LnFeatures.OPTION_ONION_MESSAGE_OPT) -def encrypt_onionmsg_tlv_hops_data(hops_data, hop_shared_secrets): +def encrypt_onionmsg_tlv_hops_data(hops_data: List[OnionHopsDataSingle], hop_shared_secrets: List[bytes]) -> None: """encrypt unencrypted onionmsg_tlv.encrypted_recipient_data for hops with blind_fields""" num_hops = len(hops_data) for i in range(num_hops): @@ -175,7 +181,12 @@ def create_onion_message_route_to(lnwallet: 'LNWallet', node_id: bytes) -> List[ raise NoRouteFound('no path found') -def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: bytes, destination_payload: dict, session_key: bytes = None): +def send_onion_message_to( + lnwallet: 'LNWallet', + node_id_or_blinded_path: bytes, + destination_payload: dict, + session_key: bytes = None +) -> None: if session_key is None: session_key = os.urandom(32) @@ -350,9 +361,19 @@ def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: bytes, ) -def get_blinded_reply_paths(lnwallet: 'LNWallet', path_id: bytes, *, - max_paths: int = REQUEST_REPLY_PATHS_MAX, - preferred_node_id: bytes = None) -> List[dict]: +def get_blinded_reply_paths( + lnwallet: 'LNWallet', + path_id: bytes, + *, + max_paths: int = REQUEST_REPLY_PATHS_MAX, + preferred_node_id: bytes = None +) -> List[dict]: + """construct a list of blinded reply_paths. + current logic: + - uses current onion_message capable channel peers if exist + - otherwise, uses current onion_message capable peers + - prefers preferred_node_id if given + - reply_path introduction points are direct peers only (TODO: longer reply paths)""" # TODO: build longer paths and/or add dummy hops to increase privacy my_active_channels = [chan for chan in lnwallet.channels.values() if chan.is_active()] my_onionmsg_channels = [chan for chan in my_active_channels if lnwallet.peers.get(chan.node_id) and @@ -361,7 +382,7 @@ def get_blinded_reply_paths(lnwallet: 'LNWallet', path_id: bytes, *, result = [] mynodeid = lnwallet.node_keypair.pubkey - mydata = {'path_id': {'data': path_id}} # same used in every path + mydata = {'path_id': {'data': path_id}} # same path_id used in every reply path if len(my_onionmsg_channels): # randomize list, but prefer preferred_node_id rchans = sorted(my_onionmsg_channels, key=lambda x: random() if x.node_id != preferred_node_id else 0) @@ -380,6 +401,7 @@ def get_blinded_reply_paths(lnwallet: 'LNWallet', path_id: bytes, *, class Timeout(Exception): pass + class OnionMessageRequest(NamedTuple): future: asyncio.Future payload: bytes @@ -387,9 +409,17 @@ class OnionMessageRequest(NamedTuple): class OnionMessageManager(Logger): - """handle state around onion message sends and receives + """handle state around onion message sends and receives. + - one instance per (ln)wallet - association between onion message and their replies - - manage re-send attempts, TODO: iterate through routes (both directions)""" + - manage re-send attempts while iterating over possible routes. Onion messages are unreliable + and fail silently if they don't reach their destination (or the reply gets dropped along the route back), + so the BOLT-4 spec suggests to send multiple messages, each with a different route to the introduction point). + - forwards are best-effort. They should not need retrying, but a queue is used to limit the pacing of forwarding, + and limiting the number of outstanding forwards. Any onion message forwards arriving when the forward queue + is full will be dropped. + + TODO: iterate through routes for each request""" SLEEP_DELAY = 1 REQUEST_REPLY_TIMEOUT = 30 @@ -403,12 +433,12 @@ class OnionMessageManager(Logger): self.network = None # type: Optional['Network'] self.taskgroup = None # type: OldTaskGroup self.lnwallet = lnwallet - self.pending = {} + self.pending = {} # type: dict[bytes, OnionMessageRequest] self.pending_lock = threading.Lock() self.send_queue = asyncio.PriorityQueue() self.forward_queue = asyncio.PriorityQueue() - def start_network(self, *, network: 'Network'): + def start_network(self, *, network: 'Network') -> None: assert network assert self.network is None, "already started" self.network = network @@ -416,17 +446,17 @@ class OnionMessageManager(Logger): asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop) @log_exceptions - async def main_loop(self): + async def main_loop(self) -> None: self.logger.info("starting taskgroup.") async with self.taskgroup as group: await group.spawn(self.process_send_queue()) await group.spawn(self.process_forward_queue()) self.logger.info("taskgroup stopped.") - async def stop(self): + async def stop(self) -> None: await self.taskgroup.cancel_remaining() - async def process_forward_queue(self): + async def process_forward_queue(self) -> None: while True: scheduled, expires, onion_packet, blinding, node_id = await self.forward_queue.get() if expires <= now(): @@ -450,13 +480,14 @@ class OnionMessageManager(Logger): ) except BaseException as e: self.logger.debug(f'error while sending {node_id=} e={e!r}') + # TODO: it is debatable whether we want to retry a forward. self.forward_queue.put_nowait((now() + self.FORWARD_RETRY_DELAY, expires, onion_packet, blinding, node_id)) def submit_forward( self, *, onion_packet: OnionPacket, blinding: bytes, - node_id: bytes): + node_id: bytes) -> None: if self.forward_queue.qsize() >= self.FORWARD_MAX_QUEUE: self.logger.debug('forward queue full, dropping packet') return @@ -464,7 +495,7 @@ class OnionMessageManager(Logger): queueitem = (now(), expires, onion_packet, blinding, node_id) self.forward_queue.put_nowait(queueitem) - async def process_send_queue(self): + async def process_send_queue(self) -> None: while True: scheduled, expires, key = await self.send_queue.get() req = self.pending.get(key) @@ -496,7 +527,7 @@ class OnionMessageManager(Logger): self.logger.debug(f'resubmit {key=}') self.send_queue.put_nowait((now() + self.REQUEST_REPLY_RETRY_DELAY, expires, key)) - def _remove_pending_message(self, key): + def _remove_pending_message(self, key: bytes) -> None: with self.pending_lock: if key in self.pending: del self.pending[key] @@ -536,13 +567,13 @@ class OnionMessageManager(Logger): task = asyncio.create_task(self._wait_task(key, req.future)) return task - async def _wait_task(self, key, future): + async def _wait_task(self, key: bytes, future: asyncio.Future): try: return await future finally: self._remove_pending_message(key) - def _send_pending_message(self, key): + def _send_pending_message(self, key: bytes) -> None: """adds reply_path to payload""" req = self.pending.get(key) payload = req.payload @@ -568,7 +599,7 @@ class OnionMessageManager(Logger): # TODO: use payload to determine prefix? return b'electrum' + key - def on_onion_message_received(self, recipient_data, payload): + def on_onion_message_received(self, recipient_data: dict, payload: dict) -> None: # we are destination, sanity checks # - if `encrypted_data_tlv` contains `allowed_features`: # - MUST ignore the message if: @@ -587,7 +618,7 @@ class OnionMessageManager(Logger): else: self.on_onion_message_received_reply(recipient_data, payload) - def on_onion_message_received_reply(self, recipient_data, payload): + def on_onion_message_received_reply(self, recipient_data: dict, payload: dict) -> None: # check if this reply is associated with a known request correl_data = recipient_data['path_id'].get('data') if not correl_data[:8] == b'electrum': @@ -600,11 +631,18 @@ class OnionMessageManager(Logger): return req.future.set_result((recipient_data, payload)) - def on_onion_message_received_unsolicited(self, recipient_data, payload): + def on_onion_message_received_unsolicited(self, recipient_data: dict, payload: dict) -> None: self.logger.debug('unsolicited onion_message received') self.logger.debug(f'payload: {payload!r}') - # TODO: currently only accepts simple text 'message' payload. + # This func currently only accepts simple text 'message' payload, a.k.a 'unknown_tag_1' + # in the bolt-4 test vectors. + # + # TODO: for BOLT-12, handle invoice_request here, which should correspond with a previously generated Offer. + # as this is not strictly part of BOLT-4, we should probably create a registration mechanism + # for various types of payloads, so we can let external code plug into onion messages + # e.g. via a decorator, something like + # @onion_message_request_handler(payload_key='invoice_request') for BOLT12 invoice requests. if 'message' not in payload: self.logger.error('Unsupported onion message payload') @@ -622,7 +660,13 @@ class OnionMessageManager(Logger): self.logger.info(f'onion message with text received: {text}') - def on_onion_message_forward(self, recipient_data, onion_packet, blinding, shared_secret): + def on_onion_message_forward( + self, + recipient_data: dict, + onion_packet: OnionPacket, + blinding: bytes, + shared_secret: bytes + ) -> None: if recipient_data.get('path_id'): self.logger.error('cannot forward onion_message, path_id in encrypted_data_tlv') return @@ -661,7 +705,8 @@ class OnionMessageManager(Logger): self.submit_forward(onion_packet=onion_packet, blinding=next_blinding, node_id=next_node_id) - def on_onion_message(self, payload): + def on_onion_message(self, payload: dict) -> None: + """handle arriving onion_message.""" blinding = payload.get('blinding') if not blinding: self.logger.error('missing blinding') @@ -676,7 +721,7 @@ class OnionMessageManager(Logger): onion_packet = OnionPacket.from_bytes(packet) self.process_onion_message_packet(blinding, onion_packet) - def process_onion_message_packet(self, blinding: bytes, onion_packet: OnionPacket): + def process_onion_message_packet(self, blinding: bytes, onion_packet: OnionPacket) -> None: our_privkey = blinding_privkey(self.lnwallet.node_keypair.privkey, blinding) processed_onion_packet = process_onion_packet(onion_packet, our_privkey, tlv_stream_name='onionmsg_tlv') payload = processed_onion_packet.hop_data.payload