1
0

onion_messages: add parameter typing and comments

This commit is contained in:
Sander van Grieken
2025-02-12 14:00:18 +01:00
parent 71b9761981
commit e216f1b324

View File

@@ -42,6 +42,7 @@ from electrum.lnonion import (get_bolt04_onion_key, OnionPacket, process_onion_p
from electrum.lnutil import LnFeatures from electrum.lnutil import LnFeatures
from electrum.util import OldTaskGroup, log_exceptions from electrum.util import OldTaskGroup, log_exceptions
# do not use util.now, because it rounds to integers # do not use util.now, because it rounds to integers
def now(): def now():
return time.time() return time.time()
@@ -66,9 +67,14 @@ class NoRouteFound(Exception):
self.peer_address = peer_address self.peer_address = peer_address
def create_blinded_path(session_key: bytes, path: List[bytes], final_recipient_data: dict, *, def create_blinded_path(
hop_extras: Optional[Sequence[dict]] = None, session_key: bytes,
dummy_hops: Optional[int] = 0) -> dict: path: List[bytes],
final_recipient_data: dict,
*,
hop_extras: Optional[Sequence[dict]] = None,
dummy_hops: Optional[int] = 0
) -> dict:
# dummy hops could be inserted anywhere in the path, but for compatibility just add them at the end # dummy hops could be inserted anywhere in the path, but for compatibility just add them at the end
# because blinded paths are usually constructed towards ourselves, and we know we can handle dummy hops. # because blinded paths are usually constructed towards ourselves, and we know we can handle dummy hops.
if dummy_hops: if dummy_hops:
@@ -114,7 +120,7 @@ def create_blinded_path(session_key: bytes, path: List[bytes], final_recipient_d
return blinded_path return blinded_path
def blinding_privkey(privkey, blinding): def blinding_privkey(privkey: bytes, blinding: bytes) -> bytes:
shared_secret = get_ecdh(privkey, blinding) shared_secret = get_ecdh(privkey, blinding)
b_hmac = get_bolt04_onion_key(b'blinded_node_id', shared_secret) b_hmac = get_bolt04_onion_key(b'blinded_node_id', shared_secret)
b_hmac_int = int.from_bytes(b_hmac, byteorder="big") b_hmac_int = int.from_bytes(b_hmac, byteorder="big")
@@ -126,13 +132,13 @@ def blinding_privkey(privkey, blinding):
return our_privkey return our_privkey
def is_onion_message_node(node_id: bytes, node_info: Optional['NodeInfo']): def is_onion_message_node(node_id: bytes, node_info: Optional['NodeInfo']) -> bool:
if not node_info: if not node_info:
return False return False
return LnFeatures(node_info.features).supports(LnFeatures.OPTION_ONION_MESSAGE_OPT) return LnFeatures(node_info.features).supports(LnFeatures.OPTION_ONION_MESSAGE_OPT)
def encrypt_onionmsg_tlv_hops_data(hops_data, hop_shared_secrets): def encrypt_onionmsg_tlv_hops_data(hops_data: List[OnionHopsDataSingle], hop_shared_secrets: List[bytes]) -> None:
"""encrypt unencrypted onionmsg_tlv.encrypted_recipient_data for hops with blind_fields""" """encrypt unencrypted onionmsg_tlv.encrypted_recipient_data for hops with blind_fields"""
num_hops = len(hops_data) num_hops = len(hops_data)
for i in range(num_hops): for i in range(num_hops):
@@ -175,7 +181,12 @@ def create_onion_message_route_to(lnwallet: 'LNWallet', node_id: bytes) -> List[
raise NoRouteFound('no path found') raise NoRouteFound('no path found')
def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: bytes, destination_payload: dict, session_key: bytes = None): def send_onion_message_to(
lnwallet: 'LNWallet',
node_id_or_blinded_path: bytes,
destination_payload: dict,
session_key: bytes = None
) -> None:
if session_key is None: if session_key is None:
session_key = os.urandom(32) session_key = os.urandom(32)
@@ -350,9 +361,19 @@ def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: bytes,
) )
def get_blinded_reply_paths(lnwallet: 'LNWallet', path_id: bytes, *, def get_blinded_reply_paths(
max_paths: int = REQUEST_REPLY_PATHS_MAX, lnwallet: 'LNWallet',
preferred_node_id: bytes = None) -> List[dict]: path_id: bytes,
*,
max_paths: int = REQUEST_REPLY_PATHS_MAX,
preferred_node_id: bytes = None
) -> List[dict]:
"""construct a list of blinded reply_paths.
current logic:
- uses current onion_message capable channel peers if exist
- otherwise, uses current onion_message capable peers
- prefers preferred_node_id if given
- reply_path introduction points are direct peers only (TODO: longer reply paths)"""
# TODO: build longer paths and/or add dummy hops to increase privacy # TODO: build longer paths and/or add dummy hops to increase privacy
my_active_channels = [chan for chan in lnwallet.channels.values() if chan.is_active()] my_active_channels = [chan for chan in lnwallet.channels.values() if chan.is_active()]
my_onionmsg_channels = [chan for chan in my_active_channels if lnwallet.peers.get(chan.node_id) and my_onionmsg_channels = [chan for chan in my_active_channels if lnwallet.peers.get(chan.node_id) and
@@ -361,7 +382,7 @@ def get_blinded_reply_paths(lnwallet: 'LNWallet', path_id: bytes, *,
result = [] result = []
mynodeid = lnwallet.node_keypair.pubkey mynodeid = lnwallet.node_keypair.pubkey
mydata = {'path_id': {'data': path_id}} # same used in every path mydata = {'path_id': {'data': path_id}} # same path_id used in every reply path
if len(my_onionmsg_channels): if len(my_onionmsg_channels):
# randomize list, but prefer preferred_node_id # randomize list, but prefer preferred_node_id
rchans = sorted(my_onionmsg_channels, key=lambda x: random() if x.node_id != preferred_node_id else 0) rchans = sorted(my_onionmsg_channels, key=lambda x: random() if x.node_id != preferred_node_id else 0)
@@ -380,6 +401,7 @@ def get_blinded_reply_paths(lnwallet: 'LNWallet', path_id: bytes, *,
class Timeout(Exception): pass class Timeout(Exception): pass
class OnionMessageRequest(NamedTuple): class OnionMessageRequest(NamedTuple):
future: asyncio.Future future: asyncio.Future
payload: bytes payload: bytes
@@ -387,9 +409,17 @@ class OnionMessageRequest(NamedTuple):
class OnionMessageManager(Logger): class OnionMessageManager(Logger):
"""handle state around onion message sends and receives """handle state around onion message sends and receives.
- one instance per (ln)wallet
- association between onion message and their replies - association between onion message and their replies
- manage re-send attempts, TODO: iterate through routes (both directions)""" - manage re-send attempts while iterating over possible routes. Onion messages are unreliable
and fail silently if they don't reach their destination (or the reply gets dropped along the route back),
so the BOLT-4 spec suggests to send multiple messages, each with a different route to the introduction point).
- forwards are best-effort. They should not need retrying, but a queue is used to limit the pacing of forwarding,
and limiting the number of outstanding forwards. Any onion message forwards arriving when the forward queue
is full will be dropped.
TODO: iterate through routes for each request"""
SLEEP_DELAY = 1 SLEEP_DELAY = 1
REQUEST_REPLY_TIMEOUT = 30 REQUEST_REPLY_TIMEOUT = 30
@@ -403,12 +433,12 @@ class OnionMessageManager(Logger):
self.network = None # type: Optional['Network'] self.network = None # type: Optional['Network']
self.taskgroup = None # type: OldTaskGroup self.taskgroup = None # type: OldTaskGroup
self.lnwallet = lnwallet self.lnwallet = lnwallet
self.pending = {} self.pending = {} # type: dict[bytes, OnionMessageRequest]
self.pending_lock = threading.Lock() self.pending_lock = threading.Lock()
self.send_queue = asyncio.PriorityQueue() self.send_queue = asyncio.PriorityQueue()
self.forward_queue = asyncio.PriorityQueue() self.forward_queue = asyncio.PriorityQueue()
def start_network(self, *, network: 'Network'): def start_network(self, *, network: 'Network') -> None:
assert network assert network
assert self.network is None, "already started" assert self.network is None, "already started"
self.network = network self.network = network
@@ -416,17 +446,17 @@ class OnionMessageManager(Logger):
asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop) asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop)
@log_exceptions @log_exceptions
async def main_loop(self): async def main_loop(self) -> None:
self.logger.info("starting taskgroup.") self.logger.info("starting taskgroup.")
async with self.taskgroup as group: async with self.taskgroup as group:
await group.spawn(self.process_send_queue()) await group.spawn(self.process_send_queue())
await group.spawn(self.process_forward_queue()) await group.spawn(self.process_forward_queue())
self.logger.info("taskgroup stopped.") self.logger.info("taskgroup stopped.")
async def stop(self): async def stop(self) -> None:
await self.taskgroup.cancel_remaining() await self.taskgroup.cancel_remaining()
async def process_forward_queue(self): async def process_forward_queue(self) -> None:
while True: while True:
scheduled, expires, onion_packet, blinding, node_id = await self.forward_queue.get() scheduled, expires, onion_packet, blinding, node_id = await self.forward_queue.get()
if expires <= now(): if expires <= now():
@@ -450,13 +480,14 @@ class OnionMessageManager(Logger):
) )
except BaseException as e: except BaseException as e:
self.logger.debug(f'error while sending {node_id=} e={e!r}') self.logger.debug(f'error while sending {node_id=} e={e!r}')
# TODO: it is debatable whether we want to retry a forward.
self.forward_queue.put_nowait((now() + self.FORWARD_RETRY_DELAY, expires, onion_packet, blinding, node_id)) self.forward_queue.put_nowait((now() + self.FORWARD_RETRY_DELAY, expires, onion_packet, blinding, node_id))
def submit_forward( def submit_forward(
self, *, self, *,
onion_packet: OnionPacket, onion_packet: OnionPacket,
blinding: bytes, blinding: bytes,
node_id: bytes): node_id: bytes) -> None:
if self.forward_queue.qsize() >= self.FORWARD_MAX_QUEUE: if self.forward_queue.qsize() >= self.FORWARD_MAX_QUEUE:
self.logger.debug('forward queue full, dropping packet') self.logger.debug('forward queue full, dropping packet')
return return
@@ -464,7 +495,7 @@ class OnionMessageManager(Logger):
queueitem = (now(), expires, onion_packet, blinding, node_id) queueitem = (now(), expires, onion_packet, blinding, node_id)
self.forward_queue.put_nowait(queueitem) self.forward_queue.put_nowait(queueitem)
async def process_send_queue(self): async def process_send_queue(self) -> None:
while True: while True:
scheduled, expires, key = await self.send_queue.get() scheduled, expires, key = await self.send_queue.get()
req = self.pending.get(key) req = self.pending.get(key)
@@ -496,7 +527,7 @@ class OnionMessageManager(Logger):
self.logger.debug(f'resubmit {key=}') self.logger.debug(f'resubmit {key=}')
self.send_queue.put_nowait((now() + self.REQUEST_REPLY_RETRY_DELAY, expires, key)) self.send_queue.put_nowait((now() + self.REQUEST_REPLY_RETRY_DELAY, expires, key))
def _remove_pending_message(self, key): def _remove_pending_message(self, key: bytes) -> None:
with self.pending_lock: with self.pending_lock:
if key in self.pending: if key in self.pending:
del self.pending[key] del self.pending[key]
@@ -536,13 +567,13 @@ class OnionMessageManager(Logger):
task = asyncio.create_task(self._wait_task(key, req.future)) task = asyncio.create_task(self._wait_task(key, req.future))
return task return task
async def _wait_task(self, key, future): async def _wait_task(self, key: bytes, future: asyncio.Future):
try: try:
return await future return await future
finally: finally:
self._remove_pending_message(key) self._remove_pending_message(key)
def _send_pending_message(self, key): def _send_pending_message(self, key: bytes) -> None:
"""adds reply_path to payload""" """adds reply_path to payload"""
req = self.pending.get(key) req = self.pending.get(key)
payload = req.payload payload = req.payload
@@ -568,7 +599,7 @@ class OnionMessageManager(Logger):
# TODO: use payload to determine prefix? # TODO: use payload to determine prefix?
return b'electrum' + key return b'electrum' + key
def on_onion_message_received(self, recipient_data, payload): def on_onion_message_received(self, recipient_data: dict, payload: dict) -> None:
# we are destination, sanity checks # we are destination, sanity checks
# - if `encrypted_data_tlv` contains `allowed_features`: # - if `encrypted_data_tlv` contains `allowed_features`:
# - MUST ignore the message if: # - MUST ignore the message if:
@@ -587,7 +618,7 @@ class OnionMessageManager(Logger):
else: else:
self.on_onion_message_received_reply(recipient_data, payload) self.on_onion_message_received_reply(recipient_data, payload)
def on_onion_message_received_reply(self, recipient_data, payload): def on_onion_message_received_reply(self, recipient_data: dict, payload: dict) -> None:
# check if this reply is associated with a known request # check if this reply is associated with a known request
correl_data = recipient_data['path_id'].get('data') correl_data = recipient_data['path_id'].get('data')
if not correl_data[:8] == b'electrum': if not correl_data[:8] == b'electrum':
@@ -600,11 +631,18 @@ class OnionMessageManager(Logger):
return return
req.future.set_result((recipient_data, payload)) req.future.set_result((recipient_data, payload))
def on_onion_message_received_unsolicited(self, recipient_data, payload): def on_onion_message_received_unsolicited(self, recipient_data: dict, payload: dict) -> None:
self.logger.debug('unsolicited onion_message received') self.logger.debug('unsolicited onion_message received')
self.logger.debug(f'payload: {payload!r}') self.logger.debug(f'payload: {payload!r}')
# TODO: currently only accepts simple text 'message' payload. # This func currently only accepts simple text 'message' payload, a.k.a 'unknown_tag_1'
# in the bolt-4 test vectors.
#
# TODO: for BOLT-12, handle invoice_request here, which should correspond with a previously generated Offer.
# as this is not strictly part of BOLT-4, we should probably create a registration mechanism
# for various types of payloads, so we can let external code plug into onion messages
# e.g. via a decorator, something like
# @onion_message_request_handler(payload_key='invoice_request') for BOLT12 invoice requests.
if 'message' not in payload: if 'message' not in payload:
self.logger.error('Unsupported onion message payload') self.logger.error('Unsupported onion message payload')
@@ -622,7 +660,13 @@ class OnionMessageManager(Logger):
self.logger.info(f'onion message with text received: {text}') self.logger.info(f'onion message with text received: {text}')
def on_onion_message_forward(self, recipient_data, onion_packet, blinding, shared_secret): def on_onion_message_forward(
self,
recipient_data: dict,
onion_packet: OnionPacket,
blinding: bytes,
shared_secret: bytes
) -> None:
if recipient_data.get('path_id'): if recipient_data.get('path_id'):
self.logger.error('cannot forward onion_message, path_id in encrypted_data_tlv') self.logger.error('cannot forward onion_message, path_id in encrypted_data_tlv')
return return
@@ -661,7 +705,8 @@ class OnionMessageManager(Logger):
self.submit_forward(onion_packet=onion_packet, blinding=next_blinding, node_id=next_node_id) self.submit_forward(onion_packet=onion_packet, blinding=next_blinding, node_id=next_node_id)
def on_onion_message(self, payload): def on_onion_message(self, payload: dict) -> None:
"""handle arriving onion_message."""
blinding = payload.get('blinding') blinding = payload.get('blinding')
if not blinding: if not blinding:
self.logger.error('missing blinding') self.logger.error('missing blinding')
@@ -676,7 +721,7 @@ class OnionMessageManager(Logger):
onion_packet = OnionPacket.from_bytes(packet) onion_packet = OnionPacket.from_bytes(packet)
self.process_onion_message_packet(blinding, onion_packet) self.process_onion_message_packet(blinding, onion_packet)
def process_onion_message_packet(self, blinding: bytes, onion_packet: OnionPacket): def process_onion_message_packet(self, blinding: bytes, onion_packet: OnionPacket) -> None:
our_privkey = blinding_privkey(self.lnwallet.node_keypair.privkey, blinding) our_privkey = blinding_privkey(self.lnwallet.node_keypair.privkey, blinding)
processed_onion_packet = process_onion_packet(onion_packet, our_privkey, tlv_stream_name='onionmsg_tlv') processed_onion_packet = process_onion_packet(onion_packet, our_privkey, tlv_stream_name='onionmsg_tlv')
payload = processed_onion_packet.hop_data.payload payload = processed_onion_packet.hop_data.payload