unasync, no add_peer in create_onion_message_route_to, add manager tests
This commit is contained in:
@@ -1486,7 +1486,7 @@ class Commands(Logger):
|
||||
}
|
||||
|
||||
try:
|
||||
await send_onion_message_to(wallet.lnworker, node_id_or_blinded_path, destination_payload)
|
||||
send_onion_message_to(wallet.lnworker, node_id_or_blinded_path, destination_payload)
|
||||
return {'success': True}
|
||||
except Exception as e:
|
||||
msg = str(e)
|
||||
|
||||
@@ -46,6 +46,7 @@ if TYPE_CHECKING:
|
||||
from electrum.lnworker import LNWallet
|
||||
from electrum.network import Network
|
||||
from electrum.lnrouter import NodeInfo
|
||||
from electrum.lntransport import LNPeerAddr
|
||||
from asyncio import Task
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -59,6 +60,12 @@ FORWARD_RETRY_DELAY = 2
|
||||
FORWARD_MAX_QUEUE = 3
|
||||
|
||||
|
||||
class NoRouteFound(Exception):
|
||||
def __init__(self, *args, peer_address: 'LNPeerAddr' = None):
|
||||
Exception.__init__(self, *args)
|
||||
self.peer_address = peer_address
|
||||
|
||||
|
||||
def create_blinded_path(session_key: bytes, path: List[bytes], final_recipient_data: dict, *,
|
||||
hop_extras: Optional[Sequence[dict]] = None,
|
||||
dummy_hops: Optional[int] = 0) -> dict:
|
||||
@@ -135,7 +142,7 @@ def encrypt_onionmsg_tlv_hops_data(hops_data, hop_shared_secrets):
|
||||
hops_data[i].payload['encrypted_recipient_data'] = {'encrypted_recipient_data': encrypted_recipient_data}
|
||||
|
||||
|
||||
async def create_onion_message_route_to(lnwallet: 'LNWallet', node_id: bytes) -> List[PathEdge]:
|
||||
def create_onion_message_route_to(lnwallet: 'LNWallet', node_id: bytes) -> List[PathEdge]:
|
||||
"""Constructs a route to the destination node_id, first by starting with peers with existing channels,
|
||||
and if no route found, opening a direct peer connection if node_id is found with an address in
|
||||
channel_db."""
|
||||
@@ -145,7 +152,7 @@ async def create_onion_message_route_to(lnwallet: 'LNWallet', node_id: bytes) ->
|
||||
chan.is_active() and not chan.is_frozen_for_sending()]
|
||||
my_sending_channels = {chan.short_channel_id: chan for chan in my_active_channels
|
||||
if chan.short_channel_id is not None}
|
||||
# strat1: find route to introduction point over existing channel mesh
|
||||
# find route to introduction point over existing channel mesh
|
||||
# NOTE: nodes that are in channel_db but are offline are not removed from the set
|
||||
if lnwallet.network.path_finder:
|
||||
if path := lnwallet.network.path_finder.find_path_for_payment(
|
||||
@@ -156,17 +163,19 @@ async def create_onion_message_route_to(lnwallet: 'LNWallet', node_id: bytes) ->
|
||||
my_sending_channels=my_sending_channels
|
||||
): return path
|
||||
|
||||
# strat2: dest node has host:port in channel_db? then open direct peer connection
|
||||
# alt: dest is existing peer?
|
||||
if lnwallet.peers.get(node_id):
|
||||
return [PathEdge(short_channel_id=None, start_node=None, end_node=node_id)]
|
||||
|
||||
# if we have an address, pass it.
|
||||
if lnwallet.channel_db:
|
||||
if peer_addr := lnwallet.channel_db.get_last_good_address(node_id):
|
||||
peer = await lnwallet.add_peer(str(peer_addr))
|
||||
await peer.initialized
|
||||
return [PathEdge(short_channel_id=None, start_node=None, end_node=node_id)]
|
||||
raise NoRouteFound('no path found, peer_addr available', peer_address=peer_addr)
|
||||
|
||||
raise Exception('no path found')
|
||||
raise NoRouteFound('no path found')
|
||||
|
||||
|
||||
async def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: bytes, destination_payload: dict, session_key: bytes = None):
|
||||
def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: bytes, destination_payload: dict, session_key: bytes = None):
|
||||
if session_key is None:
|
||||
session_key = os.urandom(32)
|
||||
|
||||
@@ -226,7 +235,7 @@ async def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: b
|
||||
# start of blinded path is our peer
|
||||
blinding = blinded_path['blinding']
|
||||
else:
|
||||
path = await create_onion_message_route_to(lnwallet, introduction_point)
|
||||
path = create_onion_message_route_to(lnwallet, introduction_point)
|
||||
|
||||
# first edge must be to our peer
|
||||
peer = lnwallet.peers.get(path[0].end_node)
|
||||
@@ -303,7 +312,7 @@ async def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: b
|
||||
# destination is our direct peer, no need to route-find
|
||||
path = [PathEdge(short_channel_id=None, start_node=None, end_node=pubkey)]
|
||||
else:
|
||||
path = await create_onion_message_route_to(lnwallet, pubkey)
|
||||
path = create_onion_message_route_to(lnwallet, pubkey)
|
||||
|
||||
# first edge must be to our peer
|
||||
peer = lnwallet.peers.get(path[0].end_node)
|
||||
@@ -340,9 +349,9 @@ async def send_onion_message_to(lnwallet: 'LNWallet', node_id_or_blinded_path: b
|
||||
)
|
||||
|
||||
|
||||
async def get_blinded_reply_paths(lnwallet: 'LNWallet', path_id: bytes, *,
|
||||
max_paths: int = REQUEST_REPLY_PATHS_MAX,
|
||||
preferred_node_id: bytes = None) -> List[dict]:
|
||||
def get_blinded_reply_paths(lnwallet: 'LNWallet', path_id: bytes, *,
|
||||
max_paths: int = REQUEST_REPLY_PATHS_MAX,
|
||||
preferred_node_id: bytes = None) -> List[dict]:
|
||||
# TODO: build longer paths and/or add dummy hops to increase privacy
|
||||
my_active_channels = [chan for chan in lnwallet.channels.values() if chan.is_active()]
|
||||
my_onionmsg_channels = [chan for chan in my_active_channels if lnwallet.peers.get(chan.node_id) and
|
||||
@@ -376,7 +385,9 @@ class OnionMessageManager(Logger):
|
||||
- association between onion message and their replies
|
||||
- manage re-send attempts, TODO: iterate through routes (both directions)"""
|
||||
|
||||
def __init__(self, lnwallet: 'LNWallet'):
|
||||
def __init__(self, lnwallet: 'LNWallet', *,
|
||||
request_reply_timeout=REQUEST_REPLY_TIMEOUT,
|
||||
request_reply_retry_delay=REQUEST_REPLY_RETRY_DELAY):
|
||||
Logger.__init__(self)
|
||||
self.network = None # type: Optional['Network']
|
||||
self.taskgroup = None # type: OldTaskGroup
|
||||
@@ -389,6 +400,9 @@ class OnionMessageManager(Logger):
|
||||
self.forwardqueue = queue.PriorityQueue()
|
||||
self.forwardqueue_notempty = asyncio.Event()
|
||||
|
||||
self.request_reply_timeout = request_reply_timeout
|
||||
self.request_reply_retry_delay = request_reply_retry_delay
|
||||
|
||||
def start_network(self, *, network: 'Network'):
|
||||
assert network
|
||||
assert self.network is None, "already started"
|
||||
@@ -415,13 +429,13 @@ class OnionMessageManager(Logger):
|
||||
try:
|
||||
scheduled, expires, onion_packet, blinding, node_id = self.forwardqueue.get_nowait()
|
||||
except queue.Empty:
|
||||
self.logger.debug(f'fwd queue empty')
|
||||
self.logger.info(f'forward queue empty')
|
||||
self.forwardqueue_notempty.clear()
|
||||
await self.forwardqueue_notempty.wait()
|
||||
continue
|
||||
|
||||
if expires <= now():
|
||||
self.logger.debug(f'fwd expired {node_id=}')
|
||||
self.logger.debug(f'forward expired {node_id=}')
|
||||
continue
|
||||
if scheduled > now():
|
||||
# return to queue
|
||||
@@ -448,7 +462,7 @@ class OnionMessageManager(Logger):
|
||||
blinding: bytes,
|
||||
node_id: bytes):
|
||||
if self.forwardqueue.qsize() >= FORWARD_MAX_QUEUE:
|
||||
self.logger.debug('fwd queue full, dropping packet')
|
||||
self.logger.debug('forward queue full, dropping packet')
|
||||
return
|
||||
expires = now() + FORWARD_RETRY_TIMEOUT
|
||||
queueitem = (now(), expires, onion_packet, blinding, node_id)
|
||||
@@ -460,9 +474,13 @@ class OnionMessageManager(Logger):
|
||||
try:
|
||||
scheduled, expires, key = self.requestreply_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
self.logger.debug(f'requestreply queue empty')
|
||||
self.logger.info(f'requestreply queue empty')
|
||||
self.requestreply_queue_notempty.clear()
|
||||
await self.requestreply_queue_notempty.wait()
|
||||
try:
|
||||
self.requestreply_queue_notempty.clear()
|
||||
await self.requestreply_queue_notempty.wait() # NOTE: quirk, see note below
|
||||
except Exception as e:
|
||||
self.logger.info(f'Exception e={e!r}')
|
||||
continue
|
||||
|
||||
requestreply = self.get_requestreply(key)
|
||||
@@ -483,12 +501,17 @@ class OnionMessageManager(Logger):
|
||||
continue
|
||||
|
||||
try:
|
||||
await self._send_pending_requestreply(key)
|
||||
self._send_pending_requestreply(key)
|
||||
except BaseException as e:
|
||||
self.logger.debug(f'error while sending {key=}')
|
||||
self._set_requestreply_result(key, e)
|
||||
self.logger.debug(f'error while sending {key=} {e!r}')
|
||||
self._set_requestreply_result(key, copy.copy(e))
|
||||
# NOTE: above, when passing the caught exception instance e directly it leads to GeneratorExit() in
|
||||
# queue_notempty.wait() later (??). pass a copy instead.
|
||||
if isinstance(e, NoRouteFound) and e.peer_address:
|
||||
await self.lnwallet.add_peer(str(e.peer_address))
|
||||
else:
|
||||
self.requestreply_queue.put_nowait((now() + REQUEST_REPLY_RETRY_DELAY, expires, key))
|
||||
self.logger.debug(f'resubmit {key=}')
|
||||
self.requestreply_queue.put_nowait((now() + self.request_reply_retry_delay, expires, key))
|
||||
|
||||
def get_requestreply(self, key):
|
||||
with self.pending_lock:
|
||||
@@ -498,6 +521,7 @@ class OnionMessageManager(Logger):
|
||||
with self.pending_lock:
|
||||
requestreply = self.pending.get(key)
|
||||
if requestreply is None:
|
||||
self.logger.error(f'requestreply with {key=} not found!')
|
||||
return
|
||||
self.pending[key]['result'] = result
|
||||
requestreply['ev'].set()
|
||||
@@ -537,7 +561,7 @@ class OnionMessageManager(Logger):
|
||||
}
|
||||
|
||||
# tuple = (when to process, when it expires, key)
|
||||
expires = now() + REQUEST_REPLY_TIMEOUT
|
||||
expires = now() + self.request_reply_timeout
|
||||
queueitem = (now(), expires, key)
|
||||
self.requestreply_queue.put_nowait(queueitem)
|
||||
task = asyncio.create_task(self._requestreply_task(key))
|
||||
@@ -560,12 +584,12 @@ class OnionMessageManager(Logger):
|
||||
assert requestreply
|
||||
result = requestreply.get('result')
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
raise result # raising in the task requires caller to explicitly extract exception.
|
||||
return result
|
||||
finally:
|
||||
self._remove_requestreply(key)
|
||||
|
||||
async def _send_pending_requestreply(self, key):
|
||||
def _send_pending_requestreply(self, key):
|
||||
"""adds reply_path to payload"""
|
||||
data = self.get_requestreply(key)
|
||||
payload = data.get('payload')
|
||||
@@ -577,7 +601,7 @@ class OnionMessageManager(Logger):
|
||||
if 'reply_path' not in final_payload:
|
||||
# unless explicitly set in payload, generate reply_path here
|
||||
path_id = self._path_id_from_payload_and_key(payload, key)
|
||||
reply_paths = await get_blinded_reply_paths(self.lnwallet, path_id, max_paths=1)
|
||||
reply_paths = get_blinded_reply_paths(self.lnwallet, path_id, max_paths=1)
|
||||
if not reply_paths:
|
||||
raise Exception(f'Could not create a reply_path for {key=}')
|
||||
|
||||
@@ -585,10 +609,9 @@ class OnionMessageManager(Logger):
|
||||
|
||||
# TODO: we should try alternate paths when retrying, this is currently not done.
|
||||
# (send_onion_message_to decides path, without knowledge of prev attempts)
|
||||
await send_onion_message_to(self.lnwallet, node_id_or_blinded_path, final_payload)
|
||||
send_onion_message_to(self.lnwallet, node_id_or_blinded_path, final_payload)
|
||||
|
||||
def _path_id_from_payload_and_key(self, payload: dict, key: bytes) -> bytes:
|
||||
# TODO: construct path_id in such a way that we can determine the request originated from us and is not spoofed
|
||||
# TODO: use payload to determine prefix?
|
||||
return b'electrum' + key
|
||||
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
import asyncio
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
import electrum_ecc as ecc
|
||||
from electrum_ecc import ECPrivkey
|
||||
|
||||
from electrum.lnmsg import decode_msg, OnionWireSerializer
|
||||
from electrum.lnonion import (
|
||||
@@ -10,10 +14,13 @@ from electrum.lnonion import (
|
||||
get_shared_secrets_along_route, new_onion_packet, ONION_MESSAGE_LARGE_SIZE,
|
||||
HOPS_DATA_SIZE, InvalidPayloadSize)
|
||||
from electrum.crypto import get_ecdh
|
||||
from electrum.onion_message import blinding_privkey, create_blinded_path, encrypt_onionmsg_tlv_hops_data
|
||||
from electrum.util import bfh, read_json_file
|
||||
from electrum.lnutil import LnFeatures
|
||||
from electrum.onion_message import blinding_privkey, create_blinded_path, encrypt_onionmsg_tlv_hops_data, \
|
||||
OnionMessageManager, NoRouteFound, Timeout
|
||||
from electrum.util import bfh, read_json_file, OldTaskGroup, get_asyncio_loop
|
||||
|
||||
from . import ElectrumTestCase
|
||||
from . import ElectrumTestCase, test_lnpeer
|
||||
from .test_lnpeer import PutIntoOthersQueueTransport, PeerInTests, keypair
|
||||
|
||||
# test vectors https://github.com/lightning/bolts/pull/759/files
|
||||
path = os.path.join(os.path.dirname(__file__), 'blinded-onion-message-onion-test.json')
|
||||
@@ -232,3 +239,149 @@ class TestOnionMessage(ElectrumTestCase):
|
||||
encrypt_onionmsg_tlv_hops_data(hops_data, hop_shared_secrets)
|
||||
packet = new_onion_packet(payment_path_pubkeys, SESSION_KEY, hops_data, onion_message=True)
|
||||
self.assertEqual(packet.to_bytes(), ONION_MESSAGE_PACKET)
|
||||
|
||||
|
||||
class MockNetwork:
|
||||
def __init__(self):
|
||||
self.asyncio_loop = get_asyncio_loop()
|
||||
self.taskgroup = OldTaskGroup()
|
||||
|
||||
|
||||
class MockWallet:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class MockLNWallet(test_lnpeer.MockLNWallet):
|
||||
|
||||
async def add_peer(self, connect_str: str):
|
||||
t1 = PutIntoOthersQueueTransport(self.node_keypair, 'test')
|
||||
p1 = PeerInTests(self, keypair().pubkey, t1)
|
||||
self.peers[p1.pubkey] = p1
|
||||
p1.initialized.set_result(True)
|
||||
return p1
|
||||
|
||||
|
||||
class MockPeer:
|
||||
their_features = LnFeatures(LnFeatures.OPTION_ONION_MESSAGE_OPT)
|
||||
|
||||
def __init__(self, pubkey, on_send_message=None):
|
||||
self.pubkey = pubkey
|
||||
self.on_send_message = on_send_message
|
||||
|
||||
async def wait_one_htlc_switch_iteration(self, *args):
|
||||
pass
|
||||
|
||||
def send_message(self, *args, **kwargs):
|
||||
if self.on_send_message:
|
||||
self.on_send_message(*args, **kwargs)
|
||||
|
||||
|
||||
class TestOnionMessageManager(ElectrumTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.alice = ECPrivkey(privkey_bytes=b'\x41'*32)
|
||||
self.alice_pub = self.alice.get_public_key_bytes(compressed=True)
|
||||
self.bob = ECPrivkey(privkey_bytes=b'\x42'*32)
|
||||
self.bob_pub = self.bob.get_public_key_bytes(compressed=True)
|
||||
self.carol = ECPrivkey(privkey_bytes=b'\x43'*32)
|
||||
self.carol_pub = self.carol.get_public_key_bytes(compressed=True)
|
||||
self.dave = ECPrivkey(privkey_bytes=b'\x44'*32)
|
||||
self.dave_pub = self.dave.get_public_key_bytes(compressed=True)
|
||||
self.eve = ECPrivkey(privkey_bytes=b'\x45'*32)
|
||||
self.eve_pub = self.eve.get_public_key_bytes(compressed=True)
|
||||
|
||||
async def run_test1(self, t):
|
||||
t1 = t.submit_requestreply(
|
||||
payload={'message': {'text': 'alice_timeout'.encode('utf-8')}},
|
||||
node_id_or_blinded_path=self.alice_pub)
|
||||
|
||||
with self.assertRaises(Timeout):
|
||||
await t1
|
||||
|
||||
async def run_test2(self, t):
|
||||
t2 = t.submit_requestreply(
|
||||
payload={'message': {'text': 'bob_slow_timeout'.encode('utf-8')}},
|
||||
node_id_or_blinded_path=self.bob_pub)
|
||||
|
||||
with self.assertRaises(Timeout):
|
||||
await t2
|
||||
|
||||
async def run_test3(self, t, rkey):
|
||||
t3 = t.submit_requestreply(
|
||||
payload={'message': {'text': 'carol_with_immediate_reply'.encode('utf-8')}},
|
||||
node_id_or_blinded_path=self.carol_pub,
|
||||
key=rkey)
|
||||
|
||||
t3_result = await t3
|
||||
self.assertEqual(t3_result, ({'path_id': {'data': b'electrum' + rkey}}, {}))
|
||||
|
||||
async def run_test4(self, t, rkey):
|
||||
t4 = t.submit_requestreply(
|
||||
payload={'message': {'text': 'dave_with_slow_reply'.encode('utf-8')}},
|
||||
node_id_or_blinded_path=self.dave_pub,
|
||||
key=rkey)
|
||||
|
||||
t4_result = await t4
|
||||
self.assertEqual(t4_result, ({'path_id': {'data': b'electrum' + rkey}}, {}))
|
||||
|
||||
async def run_test5(self, t):
|
||||
t5 = t.submit_requestreply(
|
||||
payload={'message': {'text': 'no_peer'.encode('utf-8')}},
|
||||
node_id_or_blinded_path=self.eve_pub)
|
||||
|
||||
with self.assertRaises(NoRouteFound):
|
||||
await t5
|
||||
|
||||
async def test_manager(self):
|
||||
n = MockNetwork()
|
||||
k = keypair()
|
||||
q1, q2 = asyncio.Queue(), asyncio.Queue()
|
||||
lnw = MockLNWallet(local_keypair=k, chans=[], tx_queue=q1, name='test', has_anchors=False)
|
||||
|
||||
def slow(*args, **kwargs):
|
||||
time.sleep(2)
|
||||
|
||||
def withreply(key, *args, **kwargs):
|
||||
t.on_onion_message_received_reply({'path_id': {'data': b'electrum' + key}}, {})
|
||||
|
||||
def slowwithreply(key, *args, **kwargs):
|
||||
time.sleep(2)
|
||||
t.on_onion_message_received_reply({'path_id': {'data': b'electrum' + key}}, {})
|
||||
|
||||
rkey1 = bfh('0102030405060708')
|
||||
rkey2 = bfh('0102030405060709')
|
||||
|
||||
lnw.peers[self.alice_pub] = MockPeer(self.alice_pub)
|
||||
lnw.peers[self.bob_pub] = MockPeer(self.bob_pub, on_send_message=slow)
|
||||
lnw.peers[self.carol_pub] = MockPeer(self.carol_pub, on_send_message=partial(withreply, rkey1))
|
||||
lnw.peers[self.dave_pub] = MockPeer(self.dave_pub, on_send_message=partial(slowwithreply, rkey2))
|
||||
t = OnionMessageManager(lnw, request_reply_timeout=5, request_reply_retry_delay=1)
|
||||
t.start_network(network=n)
|
||||
|
||||
try:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
self.logger.debug('tests in sequence')
|
||||
|
||||
await self.run_test1(t)
|
||||
await self.run_test2(t)
|
||||
await self.run_test3(t, rkey1)
|
||||
await self.run_test4(t, rkey2)
|
||||
await self.run_test5(t)
|
||||
|
||||
self.logger.debug('tests in parallel')
|
||||
|
||||
async with OldTaskGroup() as group:
|
||||
await group.spawn(self.run_test1(t))
|
||||
await group.spawn(self.run_test2(t))
|
||||
await group.spawn(self.run_test3(t, rkey1))
|
||||
await group.spawn(self.run_test4(t, rkey2))
|
||||
await group.spawn(self.run_test5(t))
|
||||
finally:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
self.logger.debug('stopping manager')
|
||||
await t.stop()
|
||||
await lnw.stop()
|
||||
|
||||
Reference in New Issue
Block a user