onion_messages: add parameter typing and comments
This commit is contained in:
@@ -42,6 +42,7 @@ from electrum.lnonion import (get_bolt04_onion_key, OnionPacket, process_onion_p
|
|||||||
from electrum.lnutil import LnFeatures
|
from electrum.lnutil import LnFeatures
|
||||||
from electrum.util import OldTaskGroup, log_exceptions
|
from electrum.util import OldTaskGroup, log_exceptions
|
||||||
|
|
||||||
|
|
||||||
# do not use util.now, because it rounds to integers
|
# do not use util.now, because it rounds to integers
|
||||||
def now():
|
def now():
|
||||||
return time.time()
|
return time.time()
|
||||||
@@ -66,9 +67,14 @@ class NoRouteFound(Exception):
|
|||||||
self.peer_address = peer_address
|
self.peer_address = peer_address
|
||||||
|
|
||||||
|
|
||||||
def create_blinded_path(session_key: bytes, path: List[bytes], final_recipient_data: dict, *,
|
def create_blinded_path(
|
||||||
hop_extras: Optional[Sequence[dict]] = None,
|
session_key: bytes,
|
||||||
dummy_hops: Optional[int] = 0) -> dict:
|
path: List[bytes],
|
||||||
|
final_recipient_data: dict,
|
||||||
|
*,
|
||||||
|
hop_extras: Optional[Sequence[dict]] = None,
|
||||||
|
dummy_hops: Optional[int] = 0
|
||||||
|
) -> dict:
|
||||||
# dummy hops could be inserted anywhere in the path, but for compatibility just add them at the end
|
# dummy hops could be inserted anywhere in the path, but for compatibility just add them at the end
|
||||||
# because blinded paths are usually constructed towards ourselves, and we know we can handle dummy hops.
|
# because blinded paths are usually constructed towards ourselves, and we know we can handle dummy hops.
|
||||||
if dummy_hops:
|
if dummy_hops:
|
||||||
@@ -114,7 +120,7 @@ def create_blinded_path(session_key: bytes, path: List[bytes], final_recipient_d
|
|||||||
return blinded_path
|
return blinded_path
|
||||||
|
|
||||||
|
|
||||||
def blinding_privkey(privkey, blinding):
|
def blinding_privkey(privkey: bytes, blinding: bytes) -> bytes:
|
||||||
shared_secret = get_ecdh(privkey, blinding)
|
shared_secret = get_ecdh(privkey, blinding)
|
||||||
b_hmac = get_bolt04_onion_key(b'blinded_node_id', shared_secret)
|
b_hmac = get_bolt04_onion_key(b'blinded_node_id', shared_secret)
|
||||||
b_hmac_int = int.from_bytes(b_hmac, byteorder="big")
|
b_hmac_int = int.from_bytes(b_hmac, byteorder="big")
|
||||||
@@ -126,13 +132,13 @@ def blinding_privkey(privkey, blinding):
|
|||||||
return our_privkey
|
return our_privkey
|
||||||
|
|
||||||
|
|
||||||
def is_onion_message_node(node_id: bytes, node_info: Optional['NodeInfo']):
|
def is_onion_message_node(node_id: bytes, node_info: Optional['NodeInfo']) -> bool:
|
||||||
if not node_info:
|
if not node_info:
|
||||||
return False
|
return False
|
||||||
return LnFeatures(node_info.features).supports(LnFeatures.OPTION_ONION_MESSAGE_OPT)
|
return LnFeatures(node_info.features).supports(LnFeatures.OPTION_ONION_MESSAGE_OPT)
|
||||||
|
|
||||||
|
|
||||||
def encrypt_onionmsg_tlv_hops_data(hops_data, hop_shared_secrets):
|
def encrypt_onionmsg_tlv_hops_data(hops_data: List[OnionHopsDataSingle], hop_shared_secrets: List[bytes]) -> None:
|
||||||
"""encrypt unencrypted onionmsg_tlv.encrypted_recipient_data for hops with blind_fields"""
|
"""encrypt unencrypted onionmsg_tlv.encrypted_recipient_data for hops with blind_fields"""
|
||||||
num_hops = len(hops_data)
|
num_hops = len(hops_data)
|
||||||
for i in range(num_hops):
|
for i in range(num_hops):
|
||||||
@@ -175,7 +181,12 @@ def create_onion_message_route_to(lnwallet: 'LNWallet', node_id: bytes) -> List[
|
|||||||
raise NoRouteFound('no path found')
|
raise NoRouteFound('no path found')
|
||||||
|
|
||||||
|
|
||||||
def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: bytes, destination_payload: dict, session_key: bytes = None):
|
def send_onion_message_to(
|
||||||
|
lnwallet: 'LNWallet',
|
||||||
|
node_id_or_blinded_path: bytes,
|
||||||
|
destination_payload: dict,
|
||||||
|
session_key: bytes = None
|
||||||
|
) -> None:
|
||||||
if session_key is None:
|
if session_key is None:
|
||||||
session_key = os.urandom(32)
|
session_key = os.urandom(32)
|
||||||
|
|
||||||
@@ -350,9 +361,19 @@ def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: bytes,
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_blinded_reply_paths(lnwallet: 'LNWallet', path_id: bytes, *,
|
def get_blinded_reply_paths(
|
||||||
max_paths: int = REQUEST_REPLY_PATHS_MAX,
|
lnwallet: 'LNWallet',
|
||||||
preferred_node_id: bytes = None) -> List[dict]:
|
path_id: bytes,
|
||||||
|
*,
|
||||||
|
max_paths: int = REQUEST_REPLY_PATHS_MAX,
|
||||||
|
preferred_node_id: bytes = None
|
||||||
|
) -> List[dict]:
|
||||||
|
"""construct a list of blinded reply_paths.
|
||||||
|
current logic:
|
||||||
|
- uses current onion_message capable channel peers if exist
|
||||||
|
- otherwise, uses current onion_message capable peers
|
||||||
|
- prefers preferred_node_id if given
|
||||||
|
- reply_path introduction points are direct peers only (TODO: longer reply paths)"""
|
||||||
# TODO: build longer paths and/or add dummy hops to increase privacy
|
# TODO: build longer paths and/or add dummy hops to increase privacy
|
||||||
my_active_channels = [chan for chan in lnwallet.channels.values() if chan.is_active()]
|
my_active_channels = [chan for chan in lnwallet.channels.values() if chan.is_active()]
|
||||||
my_onionmsg_channels = [chan for chan in my_active_channels if lnwallet.peers.get(chan.node_id) and
|
my_onionmsg_channels = [chan for chan in my_active_channels if lnwallet.peers.get(chan.node_id) and
|
||||||
@@ -361,7 +382,7 @@ def get_blinded_reply_paths(lnwallet: 'LNWallet', path_id: bytes, *,
|
|||||||
|
|
||||||
result = []
|
result = []
|
||||||
mynodeid = lnwallet.node_keypair.pubkey
|
mynodeid = lnwallet.node_keypair.pubkey
|
||||||
mydata = {'path_id': {'data': path_id}} # same used in every path
|
mydata = {'path_id': {'data': path_id}} # same path_id used in every reply path
|
||||||
if len(my_onionmsg_channels):
|
if len(my_onionmsg_channels):
|
||||||
# randomize list, but prefer preferred_node_id
|
# randomize list, but prefer preferred_node_id
|
||||||
rchans = sorted(my_onionmsg_channels, key=lambda x: random() if x.node_id != preferred_node_id else 0)
|
rchans = sorted(my_onionmsg_channels, key=lambda x: random() if x.node_id != preferred_node_id else 0)
|
||||||
@@ -380,6 +401,7 @@ def get_blinded_reply_paths(lnwallet: 'LNWallet', path_id: bytes, *,
|
|||||||
|
|
||||||
class Timeout(Exception): pass
|
class Timeout(Exception): pass
|
||||||
|
|
||||||
|
|
||||||
class OnionMessageRequest(NamedTuple):
|
class OnionMessageRequest(NamedTuple):
|
||||||
future: asyncio.Future
|
future: asyncio.Future
|
||||||
payload: bytes
|
payload: bytes
|
||||||
@@ -387,9 +409,17 @@ class OnionMessageRequest(NamedTuple):
|
|||||||
|
|
||||||
|
|
||||||
class OnionMessageManager(Logger):
|
class OnionMessageManager(Logger):
|
||||||
"""handle state around onion message sends and receives
|
"""handle state around onion message sends and receives.
|
||||||
|
- one instance per (ln)wallet
|
||||||
- association between onion message and their replies
|
- association between onion message and their replies
|
||||||
- manage re-send attempts, TODO: iterate through routes (both directions)"""
|
- manage re-send attempts while iterating over possible routes. Onion messages are unreliable
|
||||||
|
and fail silently if they don't reach their destination (or the reply gets dropped along the route back),
|
||||||
|
so the BOLT-4 spec suggests to send multiple messages, each with a different route to the introduction point).
|
||||||
|
- forwards are best-effort. They should not need retrying, but a queue is used to limit the pacing of forwarding,
|
||||||
|
and limiting the number of outstanding forwards. Any onion message forwards arriving when the forward queue
|
||||||
|
is full will be dropped.
|
||||||
|
|
||||||
|
TODO: iterate through routes for each request"""
|
||||||
|
|
||||||
SLEEP_DELAY = 1
|
SLEEP_DELAY = 1
|
||||||
REQUEST_REPLY_TIMEOUT = 30
|
REQUEST_REPLY_TIMEOUT = 30
|
||||||
@@ -403,12 +433,12 @@ class OnionMessageManager(Logger):
|
|||||||
self.network = None # type: Optional['Network']
|
self.network = None # type: Optional['Network']
|
||||||
self.taskgroup = None # type: OldTaskGroup
|
self.taskgroup = None # type: OldTaskGroup
|
||||||
self.lnwallet = lnwallet
|
self.lnwallet = lnwallet
|
||||||
self.pending = {}
|
self.pending = {} # type: dict[bytes, OnionMessageRequest]
|
||||||
self.pending_lock = threading.Lock()
|
self.pending_lock = threading.Lock()
|
||||||
self.send_queue = asyncio.PriorityQueue()
|
self.send_queue = asyncio.PriorityQueue()
|
||||||
self.forward_queue = asyncio.PriorityQueue()
|
self.forward_queue = asyncio.PriorityQueue()
|
||||||
|
|
||||||
def start_network(self, *, network: 'Network'):
|
def start_network(self, *, network: 'Network') -> None:
|
||||||
assert network
|
assert network
|
||||||
assert self.network is None, "already started"
|
assert self.network is None, "already started"
|
||||||
self.network = network
|
self.network = network
|
||||||
@@ -416,17 +446,17 @@ class OnionMessageManager(Logger):
|
|||||||
asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop)
|
asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop)
|
||||||
|
|
||||||
@log_exceptions
|
@log_exceptions
|
||||||
async def main_loop(self):
|
async def main_loop(self) -> None:
|
||||||
self.logger.info("starting taskgroup.")
|
self.logger.info("starting taskgroup.")
|
||||||
async with self.taskgroup as group:
|
async with self.taskgroup as group:
|
||||||
await group.spawn(self.process_send_queue())
|
await group.spawn(self.process_send_queue())
|
||||||
await group.spawn(self.process_forward_queue())
|
await group.spawn(self.process_forward_queue())
|
||||||
self.logger.info("taskgroup stopped.")
|
self.logger.info("taskgroup stopped.")
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self) -> None:
|
||||||
await self.taskgroup.cancel_remaining()
|
await self.taskgroup.cancel_remaining()
|
||||||
|
|
||||||
async def process_forward_queue(self):
|
async def process_forward_queue(self) -> None:
|
||||||
while True:
|
while True:
|
||||||
scheduled, expires, onion_packet, blinding, node_id = await self.forward_queue.get()
|
scheduled, expires, onion_packet, blinding, node_id = await self.forward_queue.get()
|
||||||
if expires <= now():
|
if expires <= now():
|
||||||
@@ -450,13 +480,14 @@ class OnionMessageManager(Logger):
|
|||||||
)
|
)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
self.logger.debug(f'error while sending {node_id=} e={e!r}')
|
self.logger.debug(f'error while sending {node_id=} e={e!r}')
|
||||||
|
# TODO: it is debatable whether we want to retry a forward.
|
||||||
self.forward_queue.put_nowait((now() + self.FORWARD_RETRY_DELAY, expires, onion_packet, blinding, node_id))
|
self.forward_queue.put_nowait((now() + self.FORWARD_RETRY_DELAY, expires, onion_packet, blinding, node_id))
|
||||||
|
|
||||||
def submit_forward(
|
def submit_forward(
|
||||||
self, *,
|
self, *,
|
||||||
onion_packet: OnionPacket,
|
onion_packet: OnionPacket,
|
||||||
blinding: bytes,
|
blinding: bytes,
|
||||||
node_id: bytes):
|
node_id: bytes) -> None:
|
||||||
if self.forward_queue.qsize() >= self.FORWARD_MAX_QUEUE:
|
if self.forward_queue.qsize() >= self.FORWARD_MAX_QUEUE:
|
||||||
self.logger.debug('forward queue full, dropping packet')
|
self.logger.debug('forward queue full, dropping packet')
|
||||||
return
|
return
|
||||||
@@ -464,7 +495,7 @@ class OnionMessageManager(Logger):
|
|||||||
queueitem = (now(), expires, onion_packet, blinding, node_id)
|
queueitem = (now(), expires, onion_packet, blinding, node_id)
|
||||||
self.forward_queue.put_nowait(queueitem)
|
self.forward_queue.put_nowait(queueitem)
|
||||||
|
|
||||||
async def process_send_queue(self):
|
async def process_send_queue(self) -> None:
|
||||||
while True:
|
while True:
|
||||||
scheduled, expires, key = await self.send_queue.get()
|
scheduled, expires, key = await self.send_queue.get()
|
||||||
req = self.pending.get(key)
|
req = self.pending.get(key)
|
||||||
@@ -496,7 +527,7 @@ class OnionMessageManager(Logger):
|
|||||||
self.logger.debug(f'resubmit {key=}')
|
self.logger.debug(f'resubmit {key=}')
|
||||||
self.send_queue.put_nowait((now() + self.REQUEST_REPLY_RETRY_DELAY, expires, key))
|
self.send_queue.put_nowait((now() + self.REQUEST_REPLY_RETRY_DELAY, expires, key))
|
||||||
|
|
||||||
def _remove_pending_message(self, key):
|
def _remove_pending_message(self, key: bytes) -> None:
|
||||||
with self.pending_lock:
|
with self.pending_lock:
|
||||||
if key in self.pending:
|
if key in self.pending:
|
||||||
del self.pending[key]
|
del self.pending[key]
|
||||||
@@ -536,13 +567,13 @@ class OnionMessageManager(Logger):
|
|||||||
task = asyncio.create_task(self._wait_task(key, req.future))
|
task = asyncio.create_task(self._wait_task(key, req.future))
|
||||||
return task
|
return task
|
||||||
|
|
||||||
async def _wait_task(self, key, future):
|
async def _wait_task(self, key: bytes, future: asyncio.Future):
|
||||||
try:
|
try:
|
||||||
return await future
|
return await future
|
||||||
finally:
|
finally:
|
||||||
self._remove_pending_message(key)
|
self._remove_pending_message(key)
|
||||||
|
|
||||||
def _send_pending_message(self, key):
|
def _send_pending_message(self, key: bytes) -> None:
|
||||||
"""adds reply_path to payload"""
|
"""adds reply_path to payload"""
|
||||||
req = self.pending.get(key)
|
req = self.pending.get(key)
|
||||||
payload = req.payload
|
payload = req.payload
|
||||||
@@ -568,7 +599,7 @@ class OnionMessageManager(Logger):
|
|||||||
# TODO: use payload to determine prefix?
|
# TODO: use payload to determine prefix?
|
||||||
return b'electrum' + key
|
return b'electrum' + key
|
||||||
|
|
||||||
def on_onion_message_received(self, recipient_data, payload):
|
def on_onion_message_received(self, recipient_data: dict, payload: dict) -> None:
|
||||||
# we are destination, sanity checks
|
# we are destination, sanity checks
|
||||||
# - if `encrypted_data_tlv` contains `allowed_features`:
|
# - if `encrypted_data_tlv` contains `allowed_features`:
|
||||||
# - MUST ignore the message if:
|
# - MUST ignore the message if:
|
||||||
@@ -587,7 +618,7 @@ class OnionMessageManager(Logger):
|
|||||||
else:
|
else:
|
||||||
self.on_onion_message_received_reply(recipient_data, payload)
|
self.on_onion_message_received_reply(recipient_data, payload)
|
||||||
|
|
||||||
def on_onion_message_received_reply(self, recipient_data, payload):
|
def on_onion_message_received_reply(self, recipient_data: dict, payload: dict) -> None:
|
||||||
# check if this reply is associated with a known request
|
# check if this reply is associated with a known request
|
||||||
correl_data = recipient_data['path_id'].get('data')
|
correl_data = recipient_data['path_id'].get('data')
|
||||||
if not correl_data[:8] == b'electrum':
|
if not correl_data[:8] == b'electrum':
|
||||||
@@ -600,11 +631,18 @@ class OnionMessageManager(Logger):
|
|||||||
return
|
return
|
||||||
req.future.set_result((recipient_data, payload))
|
req.future.set_result((recipient_data, payload))
|
||||||
|
|
||||||
def on_onion_message_received_unsolicited(self, recipient_data, payload):
|
def on_onion_message_received_unsolicited(self, recipient_data: dict, payload: dict) -> None:
|
||||||
self.logger.debug('unsolicited onion_message received')
|
self.logger.debug('unsolicited onion_message received')
|
||||||
self.logger.debug(f'payload: {payload!r}')
|
self.logger.debug(f'payload: {payload!r}')
|
||||||
|
|
||||||
# TODO: currently only accepts simple text 'message' payload.
|
# This func currently only accepts simple text 'message' payload, a.k.a 'unknown_tag_1'
|
||||||
|
# in the bolt-4 test vectors.
|
||||||
|
#
|
||||||
|
# TODO: for BOLT-12, handle invoice_request here, which should correspond with a previously generated Offer.
|
||||||
|
# as this is not strictly part of BOLT-4, we should probably create a registration mechanism
|
||||||
|
# for various types of payloads, so we can let external code plug into onion messages
|
||||||
|
# e.g. via a decorator, something like
|
||||||
|
# @onion_message_request_handler(payload_key='invoice_request') for BOLT12 invoice requests.
|
||||||
|
|
||||||
if 'message' not in payload:
|
if 'message' not in payload:
|
||||||
self.logger.error('Unsupported onion message payload')
|
self.logger.error('Unsupported onion message payload')
|
||||||
@@ -622,7 +660,13 @@ class OnionMessageManager(Logger):
|
|||||||
|
|
||||||
self.logger.info(f'onion message with text received: {text}')
|
self.logger.info(f'onion message with text received: {text}')
|
||||||
|
|
||||||
def on_onion_message_forward(self, recipient_data, onion_packet, blinding, shared_secret):
|
def on_onion_message_forward(
|
||||||
|
self,
|
||||||
|
recipient_data: dict,
|
||||||
|
onion_packet: OnionPacket,
|
||||||
|
blinding: bytes,
|
||||||
|
shared_secret: bytes
|
||||||
|
) -> None:
|
||||||
if recipient_data.get('path_id'):
|
if recipient_data.get('path_id'):
|
||||||
self.logger.error('cannot forward onion_message, path_id in encrypted_data_tlv')
|
self.logger.error('cannot forward onion_message, path_id in encrypted_data_tlv')
|
||||||
return
|
return
|
||||||
@@ -661,7 +705,8 @@ class OnionMessageManager(Logger):
|
|||||||
|
|
||||||
self.submit_forward(onion_packet=onion_packet, blinding=next_blinding, node_id=next_node_id)
|
self.submit_forward(onion_packet=onion_packet, blinding=next_blinding, node_id=next_node_id)
|
||||||
|
|
||||||
def on_onion_message(self, payload):
|
def on_onion_message(self, payload: dict) -> None:
|
||||||
|
"""handle arriving onion_message."""
|
||||||
blinding = payload.get('blinding')
|
blinding = payload.get('blinding')
|
||||||
if not blinding:
|
if not blinding:
|
||||||
self.logger.error('missing blinding')
|
self.logger.error('missing blinding')
|
||||||
@@ -676,7 +721,7 @@ class OnionMessageManager(Logger):
|
|||||||
onion_packet = OnionPacket.from_bytes(packet)
|
onion_packet = OnionPacket.from_bytes(packet)
|
||||||
self.process_onion_message_packet(blinding, onion_packet)
|
self.process_onion_message_packet(blinding, onion_packet)
|
||||||
|
|
||||||
def process_onion_message_packet(self, blinding: bytes, onion_packet: OnionPacket):
|
def process_onion_message_packet(self, blinding: bytes, onion_packet: OnionPacket) -> None:
|
||||||
our_privkey = blinding_privkey(self.lnwallet.node_keypair.privkey, blinding)
|
our_privkey = blinding_privkey(self.lnwallet.node_keypair.privkey, blinding)
|
||||||
processed_onion_packet = process_onion_packet(onion_packet, our_privkey, tlv_stream_name='onionmsg_tlv')
|
processed_onion_packet = process_onion_packet(onion_packet, our_privkey, tlv_stream_name='onionmsg_tlv')
|
||||||
payload = processed_onion_packet.hop_data.payload
|
payload = processed_onion_packet.hop_data.payload
|
||||||
|
|||||||
Reference in New Issue
Block a user