onion_messages: add tests for forwards, receive unsolicited.
This commit is contained in:
@@ -20,6 +20,7 @@
|
||||
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import io
|
||||
@@ -43,7 +44,6 @@ from electrum.lnutil import LnFeatures
|
||||
from electrum.util import OldTaskGroup, log_exceptions
|
||||
|
||||
|
||||
# do not use util.now, because it rounds to integers
|
||||
def now():
|
||||
return time.time()
|
||||
|
||||
@@ -411,12 +411,6 @@ def get_blinded_reply_paths(
|
||||
class Timeout(Exception): pass
|
||||
|
||||
|
||||
class OnionMessageRequest(NamedTuple):
|
||||
future: asyncio.Future
|
||||
payload: bytes
|
||||
node_id_or_blinded_path: bytes
|
||||
|
||||
|
||||
class OnionMessageManager(Logger):
|
||||
"""handle state around onion message sends and receives.
|
||||
- one instance per (ln)wallet
|
||||
@@ -437,12 +431,17 @@ class OnionMessageManager(Logger):
|
||||
FORWARD_RETRY_DELAY = 2
|
||||
FORWARD_MAX_QUEUE = 3
|
||||
|
||||
class Request(NamedTuple):
|
||||
future: asyncio.Future
|
||||
payload: dict
|
||||
node_id_or_blinded_path: bytes
|
||||
|
||||
def __init__(self, lnwallet: 'LNWallet'):
|
||||
Logger.__init__(self)
|
||||
self.network = None # type: Optional['Network']
|
||||
self.taskgroup = None # type: OldTaskGroup
|
||||
self.lnwallet = lnwallet
|
||||
self.pending = {} # type: dict[bytes, OnionMessageRequest]
|
||||
self.pending = {} # type: dict[bytes, OnionMessageManager.Request]
|
||||
self.pending_lock = threading.Lock()
|
||||
self.send_queue = asyncio.PriorityQueue()
|
||||
self.forward_queue = asyncio.PriorityQueue()
|
||||
@@ -559,7 +558,7 @@ class OnionMessageManager(Logger):
|
||||
|
||||
self.logger.debug(f'submit_send {key=} {payload=} {node_id_or_blinded_path=}')
|
||||
|
||||
req = OnionMessageRequest(
|
||||
req = OnionMessageManager.Request(
|
||||
future=asyncio.Future(),
|
||||
payload=payload,
|
||||
node_id_or_blinded_path=node_id_or_blinded_path
|
||||
@@ -608,6 +607,19 @@ class OnionMessageManager(Logger):
|
||||
# TODO: use payload to determine prefix?
|
||||
return b'electrum' + key
|
||||
|
||||
def _get_request_for_path_id(self, recipient_data: dict) -> Request:
|
||||
path_id = recipient_data.get('path_id', {}).get('data')
|
||||
if not path_id:
|
||||
return None
|
||||
if not path_id[:8] == b'electrum':
|
||||
self.logger.warning('not a reply to our request (unknown path_id prefix)')
|
||||
return None
|
||||
key = path_id[8:]
|
||||
req = self.pending.get(key)
|
||||
if req is None:
|
||||
self.logger.warning('not a reply to our request (unknown request)')
|
||||
return req
|
||||
|
||||
def on_onion_message_received(self, recipient_data: dict, payload: dict) -> None:
|
||||
# we are destination, sanity checks
|
||||
# - if `encrypted_data_tlv` contains `allowed_features`:
|
||||
@@ -622,25 +634,16 @@ class OnionMessageManager(Logger):
|
||||
# - if `path_id` is set and corresponds to a path the reader has previously published in a `reply_path`:
|
||||
# - if the onion message is not a reply to that previous onion:
|
||||
# - MUST ignore the onion message
|
||||
# TODO: store path_id and lookup here
|
||||
if 'path_id' not in recipient_data:
|
||||
req = self._get_request_for_path_id(recipient_data)
|
||||
if req is None:
|
||||
# unsolicited onion_message
|
||||
self.on_onion_message_received_unsolicited(recipient_data, payload)
|
||||
else:
|
||||
self.on_onion_message_received_reply(recipient_data, payload)
|
||||
self.on_onion_message_received_reply(req, recipient_data, payload)
|
||||
|
||||
def on_onion_message_received_reply(self, recipient_data: dict, payload: dict) -> None:
|
||||
# check if this reply is associated with a known request
|
||||
correl_data = recipient_data['path_id'].get('data')
|
||||
if not correl_data[:8] == b'electrum':
|
||||
self.logger.warning('not a reply to our request (unknown path_id prefix)')
|
||||
return
|
||||
key = correl_data[8:]
|
||||
req = self.pending.get(key)
|
||||
if req is None:
|
||||
self.logger.warning('not a reply to our request (unknown request)')
|
||||
return
|
||||
req.future.set_result((recipient_data, payload))
|
||||
def on_onion_message_received_reply(self, request: Request, recipient_data: dict, payload: dict) -> None:
|
||||
assert request is not None, 'Request is mandatory'
|
||||
request.future.set_result((recipient_data, payload))
|
||||
|
||||
def on_onion_message_received_unsolicited(self, recipient_data: dict, payload: dict) -> None:
|
||||
self.logger.debug('unsolicited onion_message received')
|
||||
|
||||
@@ -14,8 +14,8 @@ from electrum.lnonion import (
|
||||
process_onion_packet, get_bolt04_onion_key, encrypt_onionmsg_data_tlv,
|
||||
get_shared_secrets_along_route, new_onion_packet, ONION_MESSAGE_LARGE_SIZE,
|
||||
HOPS_DATA_SIZE, InvalidPayloadSize)
|
||||
from electrum.crypto import get_ecdh
|
||||
from electrum.lnutil import LnFeatures
|
||||
from electrum.crypto import get_ecdh, privkey_to_pubkey
|
||||
from electrum.lnutil import LnFeatures, Keypair
|
||||
from electrum.onion_message import blinding_privkey, create_blinded_path, encrypt_onionmsg_tlv_hops_data, \
|
||||
OnionMessageManager, NoRouteFound, Timeout
|
||||
from electrum.util import bfh, read_json_file, OldTaskGroup, get_asyncio_loop
|
||||
@@ -304,21 +304,21 @@ class TestOnionMessageManager(ElectrumTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.alice = ECPrivkey(privkey_bytes=b'\x41'*32)
|
||||
self.alice_pub = self.alice.get_public_key_bytes(compressed=True)
|
||||
self.bob = ECPrivkey(privkey_bytes=b'\x42'*32)
|
||||
self.bob_pub = self.bob.get_public_key_bytes(compressed=True)
|
||||
self.carol = ECPrivkey(privkey_bytes=b'\x43'*32)
|
||||
self.carol_pub = self.carol.get_public_key_bytes(compressed=True)
|
||||
self.dave = ECPrivkey(privkey_bytes=b'\x44'*32)
|
||||
self.dave_pub = self.dave.get_public_key_bytes(compressed=True)
|
||||
self.eve = ECPrivkey(privkey_bytes=b'\x45'*32)
|
||||
self.eve_pub = self.eve.get_public_key_bytes(compressed=True)
|
||||
|
||||
def keypair(privkey: ECPrivkey):
|
||||
priv = privkey.get_secret_bytes()
|
||||
return Keypair(pubkey=privkey_to_pubkey(priv), privkey=priv)
|
||||
|
||||
self.alice = keypair(ECPrivkey(privkey_bytes=b'\x41'*32))
|
||||
self.bob = keypair(ECPrivkey(privkey_bytes=b'\x42'*32))
|
||||
self.carol = keypair(ECPrivkey(privkey_bytes=b'\x43'*32))
|
||||
self.dave = keypair(ECPrivkey(privkey_bytes=b'\x44'*32))
|
||||
self.eve = keypair(ECPrivkey(privkey_bytes=b'\x45'*32))
|
||||
|
||||
async def run_test1(self, t):
|
||||
t1 = t.submit_send(
|
||||
payload={'message': {'text': 'alice_timeout'.encode('utf-8')}},
|
||||
node_id_or_blinded_path=self.alice_pub)
|
||||
node_id_or_blinded_path=self.alice.pubkey)
|
||||
|
||||
with self.assertRaises(Timeout):
|
||||
await t1
|
||||
@@ -326,7 +326,7 @@ class TestOnionMessageManager(ElectrumTestCase):
|
||||
async def run_test2(self, t):
|
||||
t2 = t.submit_send(
|
||||
payload={'message': {'text': 'bob_slow_timeout'.encode('utf-8')}},
|
||||
node_id_or_blinded_path=self.bob_pub)
|
||||
node_id_or_blinded_path=self.bob.pubkey)
|
||||
|
||||
with self.assertRaises(Timeout):
|
||||
await t2
|
||||
@@ -334,7 +334,7 @@ class TestOnionMessageManager(ElectrumTestCase):
|
||||
async def run_test3(self, t, rkey):
|
||||
t3 = t.submit_send(
|
||||
payload={'message': {'text': 'carol_with_immediate_reply'.encode('utf-8')}},
|
||||
node_id_or_blinded_path=self.carol_pub,
|
||||
node_id_or_blinded_path=self.carol.pubkey,
|
||||
key=rkey)
|
||||
|
||||
t3_result = await t3
|
||||
@@ -343,7 +343,7 @@ class TestOnionMessageManager(ElectrumTestCase):
|
||||
async def run_test4(self, t, rkey):
|
||||
t4 = t.submit_send(
|
||||
payload={'message': {'text': 'dave_with_slow_reply'.encode('utf-8')}},
|
||||
node_id_or_blinded_path=self.dave_pub,
|
||||
node_id_or_blinded_path=self.dave.pubkey,
|
||||
key=rkey)
|
||||
|
||||
t4_result = await t4
|
||||
@@ -352,34 +352,34 @@ class TestOnionMessageManager(ElectrumTestCase):
|
||||
async def run_test5(self, t):
|
||||
t5 = t.submit_send(
|
||||
payload={'message': {'text': 'no_peer'.encode('utf-8')}},
|
||||
node_id_or_blinded_path=self.eve_pub)
|
||||
node_id_or_blinded_path=self.eve.pubkey)
|
||||
|
||||
with self.assertRaises(NoRouteFound):
|
||||
await t5
|
||||
|
||||
async def test_manager(self):
|
||||
async def test_request_and_reply(self):
|
||||
n = MockNetwork()
|
||||
k = keypair()
|
||||
q1, q2 = asyncio.Queue(), asyncio.Queue()
|
||||
lnw = MockLNWallet(local_keypair=k, chans=[], tx_queue=q1, name='test', has_anchors=False)
|
||||
lnw = MockLNWallet(local_keypair=k, chans=[], tx_queue=q1, name='test_request_and_reply', has_anchors=False)
|
||||
|
||||
def slow(*args, **kwargs):
|
||||
time.sleep(2*TIME_STEP)
|
||||
|
||||
def withreply(key, *args, **kwargs):
|
||||
t.on_onion_message_received_reply({'path_id': {'data': b'electrum' + key}}, {})
|
||||
t.on_onion_message_received({'path_id': {'data': b'electrum' + key}}, {})
|
||||
|
||||
def slowwithreply(key, *args, **kwargs):
|
||||
time.sleep(2*TIME_STEP)
|
||||
t.on_onion_message_received_reply({'path_id': {'data': b'electrum' + key}}, {})
|
||||
t.on_onion_message_received({'path_id': {'data': b'electrum' + key}}, {})
|
||||
|
||||
rkey1 = bfh('0102030405060708')
|
||||
rkey2 = bfh('0102030405060709')
|
||||
|
||||
lnw.peers[self.alice_pub] = MockPeer(self.alice_pub)
|
||||
lnw.peers[self.bob_pub] = MockPeer(self.bob_pub, on_send_message=slow)
|
||||
lnw.peers[self.carol_pub] = MockPeer(self.carol_pub, on_send_message=partial(withreply, rkey1))
|
||||
lnw.peers[self.dave_pub] = MockPeer(self.dave_pub, on_send_message=partial(slowwithreply, rkey2))
|
||||
lnw.peers[self.alice.pubkey] = MockPeer(self.alice.pubkey)
|
||||
lnw.peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=slow)
|
||||
lnw.peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(withreply, rkey1))
|
||||
lnw.peers[self.dave.pubkey] = MockPeer(self.dave.pubkey, on_send_message=partial(slowwithreply, rkey2))
|
||||
t = OnionMessageManager(lnw)
|
||||
t.start_network(network=n)
|
||||
|
||||
@@ -404,3 +404,66 @@ class TestOnionMessageManager(ElectrumTestCase):
|
||||
self.logger.debug('stopping manager')
|
||||
await t.stop()
|
||||
await lnw.stop()
|
||||
|
||||
async def test_forward(self):
|
||||
n = MockNetwork()
|
||||
q1 = asyncio.Queue()
|
||||
lnw = MockLNWallet(local_keypair=self.alice, chans=[], tx_queue=q1, name='alice', has_anchors=False)
|
||||
|
||||
self.was_sent = False
|
||||
|
||||
def on_send(to: str, *args, **kwargs):
|
||||
self.assertEqual(to, 'bob')
|
||||
self.was_sent = True
|
||||
# validate what's sent to bob
|
||||
self.assertEqual(bfh(HOPS[1]['E']), kwargs['blinding'])
|
||||
message_type, payload = decode_msg(bfh(test_vectors['decrypt']['hops'][1]['onion_message']))
|
||||
self.assertEqual(message_type, 'onion_message')
|
||||
self.assertEqual(payload['onion_message_packet'], kwargs['onion_message_packet'])
|
||||
|
||||
lnw.peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=partial(on_send, 'bob'))
|
||||
lnw.peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(on_send, 'carol'))
|
||||
t = OnionMessageManager(lnw)
|
||||
t.start_network(network=n)
|
||||
|
||||
onionmsg = bfh(test_vectors['onionmessage']['onion_message_packet'])
|
||||
try:
|
||||
t.on_onion_message({
|
||||
'blinding': bfh(test_vectors['route']['blinding']),
|
||||
'len': len(onionmsg),
|
||||
'onion_message_packet': onionmsg
|
||||
})
|
||||
finally:
|
||||
await asyncio.sleep(TIME_STEP)
|
||||
|
||||
self.logger.debug('stopping manager')
|
||||
await t.stop()
|
||||
await lnw.stop()
|
||||
|
||||
self.assertTrue(self.was_sent)
|
||||
|
||||
async def test_receive_unsolicited(self):
|
||||
n = MockNetwork()
|
||||
q1 = asyncio.Queue()
|
||||
lnw = MockLNWallet(local_keypair=self.dave, chans=[], tx_queue=q1, name='dave', has_anchors=False)
|
||||
|
||||
t = OnionMessageManager(lnw)
|
||||
t.start_network(network=n)
|
||||
|
||||
self.received_unsolicited = False
|
||||
|
||||
def my_on_onion_message_received_unsolicited(*args, **kwargs):
|
||||
self.received_unsolicited = True
|
||||
|
||||
t.on_onion_message_received_unsolicited = my_on_onion_message_received_unsolicited
|
||||
packet = bfh(test_vectors['decrypt']['hops'][3]['onion_message'])
|
||||
message_type, payload = decode_msg(packet)
|
||||
try:
|
||||
t.on_onion_message(payload)
|
||||
self.assertTrue(self.received_unsolicited)
|
||||
finally:
|
||||
await asyncio.sleep(TIME_STEP)
|
||||
|
||||
self.logger.debug('stopping manager')
|
||||
await t.stop()
|
||||
await lnw.stop()
|
||||
|
||||
Reference in New Issue
Block a user