diff --git a/electrum/onion_message.py b/electrum/onion_message.py index bb9025713..ed83e5b92 100644 --- a/electrum/onion_message.py +++ b/electrum/onion_message.py @@ -20,6 +20,7 @@ # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. + import asyncio import copy import io @@ -43,7 +44,6 @@ from electrum.lnutil import LnFeatures from electrum.util import OldTaskGroup, log_exceptions -# do not use util.now, because it rounds to integers def now(): return time.time() @@ -411,12 +411,6 @@ def get_blinded_reply_paths( class Timeout(Exception): pass -class OnionMessageRequest(NamedTuple): - future: asyncio.Future - payload: bytes - node_id_or_blinded_path: bytes - - class OnionMessageManager(Logger): """handle state around onion message sends and receives. - one instance per (ln)wallet @@ -437,12 +431,17 @@ class OnionMessageManager(Logger): FORWARD_RETRY_DELAY = 2 FORWARD_MAX_QUEUE = 3 + class Request(NamedTuple): + future: asyncio.Future + payload: dict + node_id_or_blinded_path: bytes + def __init__(self, lnwallet: 'LNWallet'): Logger.__init__(self) self.network = None # type: Optional['Network'] self.taskgroup = None # type: OldTaskGroup self.lnwallet = lnwallet - self.pending = {} # type: dict[bytes, OnionMessageRequest] + self.pending = {} # type: dict[bytes, OnionMessageManager.Request] self.pending_lock = threading.Lock() self.send_queue = asyncio.PriorityQueue() self.forward_queue = asyncio.PriorityQueue() @@ -559,7 +558,7 @@ class OnionMessageManager(Logger): self.logger.debug(f'submit_send {key=} {payload=} {node_id_or_blinded_path=}') - req = OnionMessageRequest( + req = OnionMessageManager.Request( future=asyncio.Future(), payload=payload, node_id_or_blinded_path=node_id_or_blinded_path @@ -608,6 +607,19 @@ class OnionMessageManager(Logger): # TODO: use payload to determine prefix? return b'electrum' + key + def _get_request_for_path_id(self, recipient_data: dict) -> Request: + path_id = recipient_data.get('path_id', {}).get('data') + if not path_id: + return None + if not path_id[:8] == b'electrum': + self.logger.warning('not a reply to our request (unknown path_id prefix)') + return None + key = path_id[8:] + req = self.pending.get(key) + if req is None: + self.logger.warning('not a reply to our request (unknown request)') + return req + def on_onion_message_received(self, recipient_data: dict, payload: dict) -> None: # we are destination, sanity checks # - if `encrypted_data_tlv` contains `allowed_features`: @@ -622,25 +634,16 @@ class OnionMessageManager(Logger): # - if `path_id` is set and corresponds to a path the reader has previously published in a `reply_path`: # - if the onion message is not a reply to that previous onion: # - MUST ignore the onion message - # TODO: store path_id and lookup here - if 'path_id' not in recipient_data: + req = self._get_request_for_path_id(recipient_data) + if req is None: # unsolicited onion_message self.on_onion_message_received_unsolicited(recipient_data, payload) else: - self.on_onion_message_received_reply(recipient_data, payload) + self.on_onion_message_received_reply(req, recipient_data, payload) - def on_onion_message_received_reply(self, recipient_data: dict, payload: dict) -> None: - # check if this reply is associated with a known request - correl_data = recipient_data['path_id'].get('data') - if not correl_data[:8] == b'electrum': - self.logger.warning('not a reply to our request (unknown path_id prefix)') - return - key = correl_data[8:] - req = self.pending.get(key) - if req is None: - self.logger.warning('not a reply to our request (unknown request)') - return - req.future.set_result((recipient_data, payload)) + def on_onion_message_received_reply(self, request: Request, recipient_data: dict, payload: dict) -> None: + assert request is not None, 'Request is mandatory' + request.future.set_result((recipient_data, payload)) def on_onion_message_received_unsolicited(self, recipient_data: dict, payload: dict) -> None: self.logger.debug('unsolicited onion_message received') diff --git a/tests/test_onion_message.py b/tests/test_onion_message.py index 295f55401..a65a90be8 100644 --- a/tests/test_onion_message.py +++ b/tests/test_onion_message.py @@ -14,8 +14,8 @@ from electrum.lnonion import ( process_onion_packet, get_bolt04_onion_key, encrypt_onionmsg_data_tlv, get_shared_secrets_along_route, new_onion_packet, ONION_MESSAGE_LARGE_SIZE, HOPS_DATA_SIZE, InvalidPayloadSize) -from electrum.crypto import get_ecdh -from electrum.lnutil import LnFeatures +from electrum.crypto import get_ecdh, privkey_to_pubkey +from electrum.lnutil import LnFeatures, Keypair from electrum.onion_message import blinding_privkey, create_blinded_path, encrypt_onionmsg_tlv_hops_data, \ OnionMessageManager, NoRouteFound, Timeout from electrum.util import bfh, read_json_file, OldTaskGroup, get_asyncio_loop @@ -304,21 +304,21 @@ class TestOnionMessageManager(ElectrumTestCase): def setUp(self): super().setUp() - self.alice = ECPrivkey(privkey_bytes=b'\x41'*32) - self.alice_pub = self.alice.get_public_key_bytes(compressed=True) - self.bob = ECPrivkey(privkey_bytes=b'\x42'*32) - self.bob_pub = self.bob.get_public_key_bytes(compressed=True) - self.carol = ECPrivkey(privkey_bytes=b'\x43'*32) - self.carol_pub = self.carol.get_public_key_bytes(compressed=True) - self.dave = ECPrivkey(privkey_bytes=b'\x44'*32) - self.dave_pub = self.dave.get_public_key_bytes(compressed=True) - self.eve = ECPrivkey(privkey_bytes=b'\x45'*32) - self.eve_pub = self.eve.get_public_key_bytes(compressed=True) + + def keypair(privkey: ECPrivkey): + priv = privkey.get_secret_bytes() + return Keypair(pubkey=privkey_to_pubkey(priv), privkey=priv) + + self.alice = keypair(ECPrivkey(privkey_bytes=b'\x41'*32)) + self.bob = keypair(ECPrivkey(privkey_bytes=b'\x42'*32)) + self.carol = keypair(ECPrivkey(privkey_bytes=b'\x43'*32)) + self.dave = keypair(ECPrivkey(privkey_bytes=b'\x44'*32)) + self.eve = keypair(ECPrivkey(privkey_bytes=b'\x45'*32)) async def run_test1(self, t): t1 = t.submit_send( payload={'message': {'text': 'alice_timeout'.encode('utf-8')}}, - node_id_or_blinded_path=self.alice_pub) + node_id_or_blinded_path=self.alice.pubkey) with self.assertRaises(Timeout): await t1 @@ -326,7 +326,7 @@ class TestOnionMessageManager(ElectrumTestCase): async def run_test2(self, t): t2 = t.submit_send( payload={'message': {'text': 'bob_slow_timeout'.encode('utf-8')}}, - node_id_or_blinded_path=self.bob_pub) + node_id_or_blinded_path=self.bob.pubkey) with self.assertRaises(Timeout): await t2 @@ -334,7 +334,7 @@ class TestOnionMessageManager(ElectrumTestCase): async def run_test3(self, t, rkey): t3 = t.submit_send( payload={'message': {'text': 'carol_with_immediate_reply'.encode('utf-8')}}, - node_id_or_blinded_path=self.carol_pub, + node_id_or_blinded_path=self.carol.pubkey, key=rkey) t3_result = await t3 @@ -343,7 +343,7 @@ class TestOnionMessageManager(ElectrumTestCase): async def run_test4(self, t, rkey): t4 = t.submit_send( payload={'message': {'text': 'dave_with_slow_reply'.encode('utf-8')}}, - node_id_or_blinded_path=self.dave_pub, + node_id_or_blinded_path=self.dave.pubkey, key=rkey) t4_result = await t4 @@ -352,34 +352,34 @@ class TestOnionMessageManager(ElectrumTestCase): async def run_test5(self, t): t5 = t.submit_send( payload={'message': {'text': 'no_peer'.encode('utf-8')}}, - node_id_or_blinded_path=self.eve_pub) + node_id_or_blinded_path=self.eve.pubkey) with self.assertRaises(NoRouteFound): await t5 - async def test_manager(self): + async def test_request_and_reply(self): n = MockNetwork() k = keypair() q1, q2 = asyncio.Queue(), asyncio.Queue() - lnw = MockLNWallet(local_keypair=k, chans=[], tx_queue=q1, name='test', has_anchors=False) + lnw = MockLNWallet(local_keypair=k, chans=[], tx_queue=q1, name='test_request_and_reply', has_anchors=False) def slow(*args, **kwargs): time.sleep(2*TIME_STEP) def withreply(key, *args, **kwargs): - t.on_onion_message_received_reply({'path_id': {'data': b'electrum' + key}}, {}) + t.on_onion_message_received({'path_id': {'data': b'electrum' + key}}, {}) def slowwithreply(key, *args, **kwargs): time.sleep(2*TIME_STEP) - t.on_onion_message_received_reply({'path_id': {'data': b'electrum' + key}}, {}) + t.on_onion_message_received({'path_id': {'data': b'electrum' + key}}, {}) rkey1 = bfh('0102030405060708') rkey2 = bfh('0102030405060709') - lnw.peers[self.alice_pub] = MockPeer(self.alice_pub) - lnw.peers[self.bob_pub] = MockPeer(self.bob_pub, on_send_message=slow) - lnw.peers[self.carol_pub] = MockPeer(self.carol_pub, on_send_message=partial(withreply, rkey1)) - lnw.peers[self.dave_pub] = MockPeer(self.dave_pub, on_send_message=partial(slowwithreply, rkey2)) + lnw.peers[self.alice.pubkey] = MockPeer(self.alice.pubkey) + lnw.peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=slow) + lnw.peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(withreply, rkey1)) + lnw.peers[self.dave.pubkey] = MockPeer(self.dave.pubkey, on_send_message=partial(slowwithreply, rkey2)) t = OnionMessageManager(lnw) t.start_network(network=n) @@ -404,3 +404,66 @@ class TestOnionMessageManager(ElectrumTestCase): self.logger.debug('stopping manager') await t.stop() await lnw.stop() + + async def test_forward(self): + n = MockNetwork() + q1 = asyncio.Queue() + lnw = MockLNWallet(local_keypair=self.alice, chans=[], tx_queue=q1, name='alice', has_anchors=False) + + self.was_sent = False + + def on_send(to: str, *args, **kwargs): + self.assertEqual(to, 'bob') + self.was_sent = True + # validate what's sent to bob + self.assertEqual(bfh(HOPS[1]['E']), kwargs['blinding']) + message_type, payload = decode_msg(bfh(test_vectors['decrypt']['hops'][1]['onion_message'])) + self.assertEqual(message_type, 'onion_message') + self.assertEqual(payload['onion_message_packet'], kwargs['onion_message_packet']) + + lnw.peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=partial(on_send, 'bob')) + lnw.peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(on_send, 'carol')) + t = OnionMessageManager(lnw) + t.start_network(network=n) + + onionmsg = bfh(test_vectors['onionmessage']['onion_message_packet']) + try: + t.on_onion_message({ + 'blinding': bfh(test_vectors['route']['blinding']), + 'len': len(onionmsg), + 'onion_message_packet': onionmsg + }) + finally: + await asyncio.sleep(TIME_STEP) + + self.logger.debug('stopping manager') + await t.stop() + await lnw.stop() + + self.assertTrue(self.was_sent) + + async def test_receive_unsolicited(self): + n = MockNetwork() + q1 = asyncio.Queue() + lnw = MockLNWallet(local_keypair=self.dave, chans=[], tx_queue=q1, name='dave', has_anchors=False) + + t = OnionMessageManager(lnw) + t.start_network(network=n) + + self.received_unsolicited = False + + def my_on_onion_message_received_unsolicited(*args, **kwargs): + self.received_unsolicited = True + + t.on_onion_message_received_unsolicited = my_on_onion_message_received_unsolicited + packet = bfh(test_vectors['decrypt']['hops'][3]['onion_message']) + message_type, payload = decode_msg(packet) + try: + t.on_onion_message(payload) + self.assertTrue(self.received_unsolicited) + finally: + await asyncio.sleep(TIME_STEP) + + self.logger.debug('stopping manager') + await t.stop() + await lnw.stop()