From 7109c22317322ac29b01061c5074437f55be43f0 Mon Sep 17 00:00:00 2001 From: Sander van Grieken Date: Tue, 3 Dec 2024 15:58:10 +0100 Subject: [PATCH] unasync, no add_peer in create_onion_message_route_to, add manager tests --- electrum/commands.py | 2 +- electrum/onion_message.py | 81 +++++++++++------- tests/test_onion_message.py | 159 +++++++++++++++++++++++++++++++++++- 3 files changed, 209 insertions(+), 33 deletions(-) diff --git a/electrum/commands.py b/electrum/commands.py index a389dd6da..dcda79c46 100644 --- a/electrum/commands.py +++ b/electrum/commands.py @@ -1486,7 +1486,7 @@ class Commands(Logger): } try: - await send_onion_message_to(wallet.lnworker, node_id_or_blinded_path, destination_payload) + send_onion_message_to(wallet.lnworker, node_id_or_blinded_path, destination_payload) return {'success': True} except Exception as e: msg = str(e) diff --git a/electrum/onion_message.py b/electrum/onion_message.py index 1da74a7b2..018458ba2 100644 --- a/electrum/onion_message.py +++ b/electrum/onion_message.py @@ -46,6 +46,7 @@ if TYPE_CHECKING: from electrum.lnworker import LNWallet from electrum.network import Network from electrum.lnrouter import NodeInfo + from electrum.lntransport import LNPeerAddr from asyncio import Task logger = get_logger(__name__) @@ -59,6 +60,12 @@ FORWARD_RETRY_DELAY = 2 FORWARD_MAX_QUEUE = 3 +class NoRouteFound(Exception): + def __init__(self, *args, peer_address: 'LNPeerAddr' = None): + Exception.__init__(self, *args) + self.peer_address = peer_address + + def create_blinded_path(session_key: bytes, path: List[bytes], final_recipient_data: dict, *, hop_extras: Optional[Sequence[dict]] = None, dummy_hops: Optional[int] = 0) -> dict: @@ -135,7 +142,7 @@ def encrypt_onionmsg_tlv_hops_data(hops_data, hop_shared_secrets): hops_data[i].payload['encrypted_recipient_data'] = {'encrypted_recipient_data': encrypted_recipient_data} -async def create_onion_message_route_to(lnwallet: 'LNWallet', node_id: bytes) -> List[PathEdge]: +def create_onion_message_route_to(lnwallet: 'LNWallet', node_id: bytes) -> List[PathEdge]: """Constructs a route to the destination node_id, first by starting with peers with existing channels, and if no route found, opening a direct peer connection if node_id is found with an address in channel_db.""" @@ -145,7 +152,7 @@ async def create_onion_message_route_to(lnwallet: 'LNWallet', node_id: bytes) -> chan.is_active() and not chan.is_frozen_for_sending()] my_sending_channels = {chan.short_channel_id: chan for chan in my_active_channels if chan.short_channel_id is not None} - # strat1: find route to introduction point over existing channel mesh + # find route to introduction point over existing channel mesh # NOTE: nodes that are in channel_db but are offline are not removed from the set if lnwallet.network.path_finder: if path := lnwallet.network.path_finder.find_path_for_payment( @@ -156,17 +163,19 @@ async def create_onion_message_route_to(lnwallet: 'LNWallet', node_id: bytes) -> my_sending_channels=my_sending_channels ): return path - # strat2: dest node has host:port in channel_db? then open direct peer connection + # alt: dest is existing peer? + if lnwallet.peers.get(node_id): + return [PathEdge(short_channel_id=None, start_node=None, end_node=node_id)] + + # if we have an address, pass it. if lnwallet.channel_db: if peer_addr := lnwallet.channel_db.get_last_good_address(node_id): - peer = await lnwallet.add_peer(str(peer_addr)) - await peer.initialized - return [PathEdge(short_channel_id=None, start_node=None, end_node=node_id)] + raise NoRouteFound('no path found, peer_addr available', peer_address=peer_addr) - raise Exception('no path found') + raise NoRouteFound('no path found') -async def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: bytes, destination_payload: dict, session_key: bytes = None): +def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: bytes, destination_payload: dict, session_key: bytes = None): if session_key is None: session_key = os.urandom(32) @@ -226,7 +235,7 @@ async def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: b # start of blinded path is our peer blinding = blinded_path['blinding'] else: - path = await create_onion_message_route_to(lnwallet, introduction_point) + path = create_onion_message_route_to(lnwallet, introduction_point) # first edge must be to our peer peer = lnwallet.peers.get(path[0].end_node) @@ -303,7 +312,7 @@ async def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: b # destination is our direct peer, no need to route-find path = [PathEdge(short_channel_id=None, start_node=None, end_node=pubkey)] else: - path = await create_onion_message_route_to(lnwallet, pubkey) + path = create_onion_message_route_to(lnwallet, pubkey) # first edge must be to our peer peer = lnwallet.peers.get(path[0].end_node) @@ -340,9 +349,9 @@ async def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: b ) -async def get_blinded_reply_paths(lnwallet: 'LNWallet', path_id: bytes, *, - max_paths: int = REQUEST_REPLY_PATHS_MAX, - preferred_node_id: bytes = None) -> List[dict]: +def get_blinded_reply_paths(lnwallet: 'LNWallet', path_id: bytes, *, + max_paths: int = REQUEST_REPLY_PATHS_MAX, + preferred_node_id: bytes = None) -> List[dict]: # TODO: build longer paths and/or add dummy hops to increase privacy my_active_channels = [chan for chan in lnwallet.channels.values() if chan.is_active()] my_onionmsg_channels = [chan for chan in my_active_channels if lnwallet.peers.get(chan.node_id) and @@ -376,7 +385,9 @@ class OnionMessageManager(Logger): - association between onion message and their replies - manage re-send attempts, TODO: iterate through routes (both directions)""" - def __init__(self, lnwallet: 'LNWallet'): + def __init__(self, lnwallet: 'LNWallet', *, + request_reply_timeout=REQUEST_REPLY_TIMEOUT, + request_reply_retry_delay=REQUEST_REPLY_RETRY_DELAY): Logger.__init__(self) self.network = None # type: Optional['Network'] self.taskgroup = None # type: OldTaskGroup @@ -389,6 +400,9 @@ class OnionMessageManager(Logger): self.forwardqueue = queue.PriorityQueue() self.forwardqueue_notempty = asyncio.Event() + self.request_reply_timeout = request_reply_timeout + self.request_reply_retry_delay = request_reply_retry_delay + def start_network(self, *, network: 'Network'): assert network assert self.network is None, "already started" @@ -415,13 +429,13 @@ class OnionMessageManager(Logger): try: scheduled, expires, onion_packet, blinding, node_id = self.forwardqueue.get_nowait() except queue.Empty: - self.logger.debug(f'fwd queue empty') + self.logger.info(f'forward queue empty') self.forwardqueue_notempty.clear() await self.forwardqueue_notempty.wait() continue if expires <= now(): - self.logger.debug(f'fwd expired {node_id=}') + self.logger.debug(f'forward expired {node_id=}') continue if scheduled > now(): # return to queue @@ -448,7 +462,7 @@ class OnionMessageManager(Logger): blinding: bytes, node_id: bytes): if self.forwardqueue.qsize() >= FORWARD_MAX_QUEUE: - self.logger.debug('fwd queue full, dropping packet') + self.logger.debug('forward queue full, dropping packet') return expires = now() + FORWARD_RETRY_TIMEOUT queueitem = (now(), expires, onion_packet, blinding, node_id) @@ -460,9 +474,13 @@ class OnionMessageManager(Logger): try: scheduled, expires, key = self.requestreply_queue.get_nowait() except queue.Empty: - self.logger.debug(f'requestreply queue empty') + self.logger.info(f'requestreply queue empty') self.requestreply_queue_notempty.clear() - await self.requestreply_queue_notempty.wait() + try: + self.requestreply_queue_notempty.clear() + await self.requestreply_queue_notempty.wait() # NOTE: quirk, see note below + except Exception as e: + self.logger.info(f'Exception e={e!r}') continue requestreply = self.get_requestreply(key) @@ -483,12 +501,17 @@ class OnionMessageManager(Logger): continue try: - await self._send_pending_requestreply(key) + self._send_pending_requestreply(key) except BaseException as e: - self.logger.debug(f'error while sending {key=}') - self._set_requestreply_result(key, e) + self.logger.debug(f'error while sending {key=} {e!r}') + self._set_requestreply_result(key, copy.copy(e)) + # NOTE: above, when passing the caught exception instance e directly it leads to GeneratorExit() in + # queue_notempty.wait() later (??). pass a copy instead. + if isinstance(e, NoRouteFound) and e.peer_address: + await self.lnwallet.add_peer(str(e.peer_address)) else: - self.requestreply_queue.put_nowait((now() + REQUEST_REPLY_RETRY_DELAY, expires, key)) + self.logger.debug(f'resubmit {key=}') + self.requestreply_queue.put_nowait((now() + self.request_reply_retry_delay, expires, key)) def get_requestreply(self, key): with self.pending_lock: @@ -498,6 +521,7 @@ class OnionMessageManager(Logger): with self.pending_lock: requestreply = self.pending.get(key) if requestreply is None: + self.logger.error(f'requestreply with {key=} not found!') return self.pending[key]['result'] = result requestreply['ev'].set() @@ -537,7 +561,7 @@ class OnionMessageManager(Logger): } # tuple = (when to process, when it expires, key) - expires = now() + REQUEST_REPLY_TIMEOUT + expires = now() + self.request_reply_timeout queueitem = (now(), expires, key) self.requestreply_queue.put_nowait(queueitem) task = asyncio.create_task(self._requestreply_task(key)) @@ -560,12 +584,12 @@ class OnionMessageManager(Logger): assert requestreply result = requestreply.get('result') if isinstance(result, Exception): - raise result + raise result # raising in the task requires caller to explicitly extract exception. return result finally: self._remove_requestreply(key) - async def _send_pending_requestreply(self, key): + def _send_pending_requestreply(self, key): """adds reply_path to payload""" data = self.get_requestreply(key) payload = data.get('payload') @@ -577,7 +601,7 @@ class OnionMessageManager(Logger): if 'reply_path' not in final_payload: # unless explicitly set in payload, generate reply_path here path_id = self._path_id_from_payload_and_key(payload, key) - reply_paths = await get_blinded_reply_paths(self.lnwallet, path_id, max_paths=1) + reply_paths = get_blinded_reply_paths(self.lnwallet, path_id, max_paths=1) if not reply_paths: raise Exception(f'Could not create a reply_path for {key=}') @@ -585,10 +609,9 @@ class OnionMessageManager(Logger): # TODO: we should try alternate paths when retrying, this is currently not done. # (send_onion_message_to decides path, without knowledge of prev attempts) - await send_onion_message_to(self.lnwallet, node_id_or_blinded_path, final_payload) + send_onion_message_to(self.lnwallet, node_id_or_blinded_path, final_payload) def _path_id_from_payload_and_key(self, payload: dict, key: bytes) -> bytes: - # TODO: construct path_id in such a way that we can determine the request originated from us and is not spoofed # TODO: use payload to determine prefix? return b'electrum' + key diff --git a/tests/test_onion_message.py b/tests/test_onion_message.py index 03b37fdb5..0d58fe543 100644 --- a/tests/test_onion_message.py +++ b/tests/test_onion_message.py @@ -1,7 +1,11 @@ +import asyncio import io import os +import time +from functools import partial import electrum_ecc as ecc +from electrum_ecc import ECPrivkey from electrum.lnmsg import decode_msg, OnionWireSerializer from electrum.lnonion import ( @@ -10,10 +14,13 @@ from electrum.lnonion import ( get_shared_secrets_along_route, new_onion_packet, ONION_MESSAGE_LARGE_SIZE, HOPS_DATA_SIZE, InvalidPayloadSize) from electrum.crypto import get_ecdh -from electrum.onion_message import blinding_privkey, create_blinded_path, encrypt_onionmsg_tlv_hops_data -from electrum.util import bfh, read_json_file +from electrum.lnutil import LnFeatures +from electrum.onion_message import blinding_privkey, create_blinded_path, encrypt_onionmsg_tlv_hops_data, \ + OnionMessageManager, NoRouteFound, Timeout +from electrum.util import bfh, read_json_file, OldTaskGroup, get_asyncio_loop -from . import ElectrumTestCase +from . import ElectrumTestCase, test_lnpeer +from .test_lnpeer import PutIntoOthersQueueTransport, PeerInTests, keypair # test vectors https://github.com/lightning/bolts/pull/759/files path = os.path.join(os.path.dirname(__file__), 'blinded-onion-message-onion-test.json') @@ -232,3 +239,149 @@ class TestOnionMessage(ElectrumTestCase): encrypt_onionmsg_tlv_hops_data(hops_data, hop_shared_secrets) packet = new_onion_packet(payment_path_pubkeys, SESSION_KEY, hops_data, onion_message=True) self.assertEqual(packet.to_bytes(), ONION_MESSAGE_PACKET) + + +class MockNetwork: + def __init__(self): + self.asyncio_loop = get_asyncio_loop() + self.taskgroup = OldTaskGroup() + + +class MockWallet: + def __init__(self): + pass + + +class MockLNWallet(test_lnpeer.MockLNWallet): + + async def add_peer(self, connect_str: str): + t1 = PutIntoOthersQueueTransport(self.node_keypair, 'test') + p1 = PeerInTests(self, keypair().pubkey, t1) + self.peers[p1.pubkey] = p1 + p1.initialized.set_result(True) + return p1 + + +class MockPeer: + their_features = LnFeatures(LnFeatures.OPTION_ONION_MESSAGE_OPT) + + def __init__(self, pubkey, on_send_message=None): + self.pubkey = pubkey + self.on_send_message = on_send_message + + async def wait_one_htlc_switch_iteration(self, *args): + pass + + def send_message(self, *args, **kwargs): + if self.on_send_message: + self.on_send_message(*args, **kwargs) + + +class TestOnionMessageManager(ElectrumTestCase): + + def setUp(self): + super().setUp() + self.alice = ECPrivkey(privkey_bytes=b'\x41'*32) + self.alice_pub = self.alice.get_public_key_bytes(compressed=True) + self.bob = ECPrivkey(privkey_bytes=b'\x42'*32) + self.bob_pub = self.bob.get_public_key_bytes(compressed=True) + self.carol = ECPrivkey(privkey_bytes=b'\x43'*32) + self.carol_pub = self.carol.get_public_key_bytes(compressed=True) + self.dave = ECPrivkey(privkey_bytes=b'\x44'*32) + self.dave_pub = self.dave.get_public_key_bytes(compressed=True) + self.eve = ECPrivkey(privkey_bytes=b'\x45'*32) + self.eve_pub = self.eve.get_public_key_bytes(compressed=True) + + async def run_test1(self, t): + t1 = t.submit_requestreply( + payload={'message': {'text': 'alice_timeout'.encode('utf-8')}}, + node_id_or_blinded_path=self.alice_pub) + + with self.assertRaises(Timeout): + await t1 + + async def run_test2(self, t): + t2 = t.submit_requestreply( + payload={'message': {'text': 'bob_slow_timeout'.encode('utf-8')}}, + node_id_or_blinded_path=self.bob_pub) + + with self.assertRaises(Timeout): + await t2 + + async def run_test3(self, t, rkey): + t3 = t.submit_requestreply( + payload={'message': {'text': 'carol_with_immediate_reply'.encode('utf-8')}}, + node_id_or_blinded_path=self.carol_pub, + key=rkey) + + t3_result = await t3 + self.assertEqual(t3_result, ({'path_id': {'data': b'electrum' + rkey}}, {})) + + async def run_test4(self, t, rkey): + t4 = t.submit_requestreply( + payload={'message': {'text': 'dave_with_slow_reply'.encode('utf-8')}}, + node_id_or_blinded_path=self.dave_pub, + key=rkey) + + t4_result = await t4 + self.assertEqual(t4_result, ({'path_id': {'data': b'electrum' + rkey}}, {})) + + async def run_test5(self, t): + t5 = t.submit_requestreply( + payload={'message': {'text': 'no_peer'.encode('utf-8')}}, + node_id_or_blinded_path=self.eve_pub) + + with self.assertRaises(NoRouteFound): + await t5 + + async def test_manager(self): + n = MockNetwork() + k = keypair() + q1, q2 = asyncio.Queue(), asyncio.Queue() + lnw = MockLNWallet(local_keypair=k, chans=[], tx_queue=q1, name='test', has_anchors=False) + + def slow(*args, **kwargs): + time.sleep(2) + + def withreply(key, *args, **kwargs): + t.on_onion_message_received_reply({'path_id': {'data': b'electrum' + key}}, {}) + + def slowwithreply(key, *args, **kwargs): + time.sleep(2) + t.on_onion_message_received_reply({'path_id': {'data': b'electrum' + key}}, {}) + + rkey1 = bfh('0102030405060708') + rkey2 = bfh('0102030405060709') + + lnw.peers[self.alice_pub] = MockPeer(self.alice_pub) + lnw.peers[self.bob_pub] = MockPeer(self.bob_pub, on_send_message=slow) + lnw.peers[self.carol_pub] = MockPeer(self.carol_pub, on_send_message=partial(withreply, rkey1)) + lnw.peers[self.dave_pub] = MockPeer(self.dave_pub, on_send_message=partial(slowwithreply, rkey2)) + t = OnionMessageManager(lnw, request_reply_timeout=5, request_reply_retry_delay=1) + t.start_network(network=n) + + try: + await asyncio.sleep(1) + + self.logger.debug('tests in sequence') + + await self.run_test1(t) + await self.run_test2(t) + await self.run_test3(t, rkey1) + await self.run_test4(t, rkey2) + await self.run_test5(t) + + self.logger.debug('tests in parallel') + + async with OldTaskGroup() as group: + await group.spawn(self.run_test1(t)) + await group.spawn(self.run_test2(t)) + await group.spawn(self.run_test3(t, rkey1)) + await group.spawn(self.run_test4(t, rkey2)) + await group.spawn(self.run_test5(t)) + finally: + await asyncio.sleep(1) + + self.logger.debug('stopping manager') + await t.stop() + await lnw.stop()