1
0

onion_messages: add tests for forwards, receive unsolicited.

This commit is contained in:
Sander van Grieken
2025-02-19 16:43:08 +01:00
parent 0b86e39121
commit c3c5aaab3d
2 changed files with 115 additions and 49 deletions

View File

@@ -20,6 +20,7 @@
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import asyncio
import copy
import io
@@ -43,7 +44,6 @@ from electrum.lnutil import LnFeatures
from electrum.util import OldTaskGroup, log_exceptions
# do not use util.now, because it rounds to integers
def now():
return time.time()
@@ -411,12 +411,6 @@ def get_blinded_reply_paths(
class Timeout(Exception): pass
class OnionMessageRequest(NamedTuple):
future: asyncio.Future
payload: bytes
node_id_or_blinded_path: bytes
class OnionMessageManager(Logger):
"""handle state around onion message sends and receives.
- one instance per (ln)wallet
@@ -437,12 +431,17 @@ class OnionMessageManager(Logger):
FORWARD_RETRY_DELAY = 2
FORWARD_MAX_QUEUE = 3
class Request(NamedTuple):
future: asyncio.Future
payload: dict
node_id_or_blinded_path: bytes
def __init__(self, lnwallet: 'LNWallet'):
Logger.__init__(self)
self.network = None # type: Optional['Network']
self.taskgroup = None # type: OldTaskGroup
self.lnwallet = lnwallet
self.pending = {} # type: dict[bytes, OnionMessageRequest]
self.pending = {} # type: dict[bytes, OnionMessageManager.Request]
self.pending_lock = threading.Lock()
self.send_queue = asyncio.PriorityQueue()
self.forward_queue = asyncio.PriorityQueue()
@@ -559,7 +558,7 @@ class OnionMessageManager(Logger):
self.logger.debug(f'submit_send {key=} {payload=} {node_id_or_blinded_path=}')
req = OnionMessageRequest(
req = OnionMessageManager.Request(
future=asyncio.Future(),
payload=payload,
node_id_or_blinded_path=node_id_or_blinded_path
@@ -608,6 +607,19 @@ class OnionMessageManager(Logger):
# TODO: use payload to determine prefix?
return b'electrum' + key
def _get_request_for_path_id(self, recipient_data: dict) -> Request:
path_id = recipient_data.get('path_id', {}).get('data')
if not path_id:
return None
if not path_id[:8] == b'electrum':
self.logger.warning('not a reply to our request (unknown path_id prefix)')
return None
key = path_id[8:]
req = self.pending.get(key)
if req is None:
self.logger.warning('not a reply to our request (unknown request)')
return req
def on_onion_message_received(self, recipient_data: dict, payload: dict) -> None:
# we are destination, sanity checks
# - if `encrypted_data_tlv` contains `allowed_features`:
@@ -622,25 +634,16 @@ class OnionMessageManager(Logger):
# - if `path_id` is set and corresponds to a path the reader has previously published in a `reply_path`:
# - if the onion message is not a reply to that previous onion:
# - MUST ignore the onion message
# TODO: store path_id and lookup here
if 'path_id' not in recipient_data:
req = self._get_request_for_path_id(recipient_data)
if req is None:
# unsolicited onion_message
self.on_onion_message_received_unsolicited(recipient_data, payload)
else:
self.on_onion_message_received_reply(recipient_data, payload)
self.on_onion_message_received_reply(req, recipient_data, payload)
def on_onion_message_received_reply(self, recipient_data: dict, payload: dict) -> None:
# check if this reply is associated with a known request
correl_data = recipient_data['path_id'].get('data')
if not correl_data[:8] == b'electrum':
self.logger.warning('not a reply to our request (unknown path_id prefix)')
return
key = correl_data[8:]
req = self.pending.get(key)
if req is None:
self.logger.warning('not a reply to our request (unknown request)')
return
req.future.set_result((recipient_data, payload))
def on_onion_message_received_reply(self, request: Request, recipient_data: dict, payload: dict) -> None:
assert request is not None, 'Request is mandatory'
request.future.set_result((recipient_data, payload))
def on_onion_message_received_unsolicited(self, recipient_data: dict, payload: dict) -> None:
self.logger.debug('unsolicited onion_message received')

View File

@@ -14,8 +14,8 @@ from electrum.lnonion import (
process_onion_packet, get_bolt04_onion_key, encrypt_onionmsg_data_tlv,
get_shared_secrets_along_route, new_onion_packet, ONION_MESSAGE_LARGE_SIZE,
HOPS_DATA_SIZE, InvalidPayloadSize)
from electrum.crypto import get_ecdh
from electrum.lnutil import LnFeatures
from electrum.crypto import get_ecdh, privkey_to_pubkey
from electrum.lnutil import LnFeatures, Keypair
from electrum.onion_message import blinding_privkey, create_blinded_path, encrypt_onionmsg_tlv_hops_data, \
OnionMessageManager, NoRouteFound, Timeout
from electrum.util import bfh, read_json_file, OldTaskGroup, get_asyncio_loop
@@ -304,21 +304,21 @@ class TestOnionMessageManager(ElectrumTestCase):
def setUp(self):
super().setUp()
self.alice = ECPrivkey(privkey_bytes=b'\x41'*32)
self.alice_pub = self.alice.get_public_key_bytes(compressed=True)
self.bob = ECPrivkey(privkey_bytes=b'\x42'*32)
self.bob_pub = self.bob.get_public_key_bytes(compressed=True)
self.carol = ECPrivkey(privkey_bytes=b'\x43'*32)
self.carol_pub = self.carol.get_public_key_bytes(compressed=True)
self.dave = ECPrivkey(privkey_bytes=b'\x44'*32)
self.dave_pub = self.dave.get_public_key_bytes(compressed=True)
self.eve = ECPrivkey(privkey_bytes=b'\x45'*32)
self.eve_pub = self.eve.get_public_key_bytes(compressed=True)
def keypair(privkey: ECPrivkey):
priv = privkey.get_secret_bytes()
return Keypair(pubkey=privkey_to_pubkey(priv), privkey=priv)
self.alice = keypair(ECPrivkey(privkey_bytes=b'\x41'*32))
self.bob = keypair(ECPrivkey(privkey_bytes=b'\x42'*32))
self.carol = keypair(ECPrivkey(privkey_bytes=b'\x43'*32))
self.dave = keypair(ECPrivkey(privkey_bytes=b'\x44'*32))
self.eve = keypair(ECPrivkey(privkey_bytes=b'\x45'*32))
async def run_test1(self, t):
t1 = t.submit_send(
payload={'message': {'text': 'alice_timeout'.encode('utf-8')}},
node_id_or_blinded_path=self.alice_pub)
node_id_or_blinded_path=self.alice.pubkey)
with self.assertRaises(Timeout):
await t1
@@ -326,7 +326,7 @@ class TestOnionMessageManager(ElectrumTestCase):
async def run_test2(self, t):
t2 = t.submit_send(
payload={'message': {'text': 'bob_slow_timeout'.encode('utf-8')}},
node_id_or_blinded_path=self.bob_pub)
node_id_or_blinded_path=self.bob.pubkey)
with self.assertRaises(Timeout):
await t2
@@ -334,7 +334,7 @@ class TestOnionMessageManager(ElectrumTestCase):
async def run_test3(self, t, rkey):
t3 = t.submit_send(
payload={'message': {'text': 'carol_with_immediate_reply'.encode('utf-8')}},
node_id_or_blinded_path=self.carol_pub,
node_id_or_blinded_path=self.carol.pubkey,
key=rkey)
t3_result = await t3
@@ -343,7 +343,7 @@ class TestOnionMessageManager(ElectrumTestCase):
async def run_test4(self, t, rkey):
t4 = t.submit_send(
payload={'message': {'text': 'dave_with_slow_reply'.encode('utf-8')}},
node_id_or_blinded_path=self.dave_pub,
node_id_or_blinded_path=self.dave.pubkey,
key=rkey)
t4_result = await t4
@@ -352,34 +352,34 @@ class TestOnionMessageManager(ElectrumTestCase):
async def run_test5(self, t):
t5 = t.submit_send(
payload={'message': {'text': 'no_peer'.encode('utf-8')}},
node_id_or_blinded_path=self.eve_pub)
node_id_or_blinded_path=self.eve.pubkey)
with self.assertRaises(NoRouteFound):
await t5
async def test_manager(self):
async def test_request_and_reply(self):
n = MockNetwork()
k = keypair()
q1, q2 = asyncio.Queue(), asyncio.Queue()
lnw = MockLNWallet(local_keypair=k, chans=[], tx_queue=q1, name='test', has_anchors=False)
lnw = MockLNWallet(local_keypair=k, chans=[], tx_queue=q1, name='test_request_and_reply', has_anchors=False)
def slow(*args, **kwargs):
time.sleep(2*TIME_STEP)
def withreply(key, *args, **kwargs):
t.on_onion_message_received_reply({'path_id': {'data': b'electrum' + key}}, {})
t.on_onion_message_received({'path_id': {'data': b'electrum' + key}}, {})
def slowwithreply(key, *args, **kwargs):
time.sleep(2*TIME_STEP)
t.on_onion_message_received_reply({'path_id': {'data': b'electrum' + key}}, {})
t.on_onion_message_received({'path_id': {'data': b'electrum' + key}}, {})
rkey1 = bfh('0102030405060708')
rkey2 = bfh('0102030405060709')
lnw.peers[self.alice_pub] = MockPeer(self.alice_pub)
lnw.peers[self.bob_pub] = MockPeer(self.bob_pub, on_send_message=slow)
lnw.peers[self.carol_pub] = MockPeer(self.carol_pub, on_send_message=partial(withreply, rkey1))
lnw.peers[self.dave_pub] = MockPeer(self.dave_pub, on_send_message=partial(slowwithreply, rkey2))
lnw.peers[self.alice.pubkey] = MockPeer(self.alice.pubkey)
lnw.peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=slow)
lnw.peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(withreply, rkey1))
lnw.peers[self.dave.pubkey] = MockPeer(self.dave.pubkey, on_send_message=partial(slowwithreply, rkey2))
t = OnionMessageManager(lnw)
t.start_network(network=n)
@@ -404,3 +404,66 @@ class TestOnionMessageManager(ElectrumTestCase):
self.logger.debug('stopping manager')
await t.stop()
await lnw.stop()
async def test_forward(self):
n = MockNetwork()
q1 = asyncio.Queue()
lnw = MockLNWallet(local_keypair=self.alice, chans=[], tx_queue=q1, name='alice', has_anchors=False)
self.was_sent = False
def on_send(to: str, *args, **kwargs):
self.assertEqual(to, 'bob')
self.was_sent = True
# validate what's sent to bob
self.assertEqual(bfh(HOPS[1]['E']), kwargs['blinding'])
message_type, payload = decode_msg(bfh(test_vectors['decrypt']['hops'][1]['onion_message']))
self.assertEqual(message_type, 'onion_message')
self.assertEqual(payload['onion_message_packet'], kwargs['onion_message_packet'])
lnw.peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=partial(on_send, 'bob'))
lnw.peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(on_send, 'carol'))
t = OnionMessageManager(lnw)
t.start_network(network=n)
onionmsg = bfh(test_vectors['onionmessage']['onion_message_packet'])
try:
t.on_onion_message({
'blinding': bfh(test_vectors['route']['blinding']),
'len': len(onionmsg),
'onion_message_packet': onionmsg
})
finally:
await asyncio.sleep(TIME_STEP)
self.logger.debug('stopping manager')
await t.stop()
await lnw.stop()
self.assertTrue(self.was_sent)
async def test_receive_unsolicited(self):
n = MockNetwork()
q1 = asyncio.Queue()
lnw = MockLNWallet(local_keypair=self.dave, chans=[], tx_queue=q1, name='dave', has_anchors=False)
t = OnionMessageManager(lnw)
t.start_network(network=n)
self.received_unsolicited = False
def my_on_onion_message_received_unsolicited(*args, **kwargs):
self.received_unsolicited = True
t.on_onion_message_received_unsolicited = my_on_onion_message_received_unsolicited
packet = bfh(test_vectors['decrypt']['hops'][3]['onion_message'])
message_type, payload = decode_msg(packet)
try:
t.on_onion_message(payload)
self.assertTrue(self.received_unsolicited)
finally:
await asyncio.sleep(TIME_STEP)
self.logger.debug('stopping manager')
await t.stop()
await lnw.stop()