onion_messages: add tests for forwards, receive unsolicited.
This commit is contained in:
@@ -20,6 +20,7 @@
|
|||||||
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import io
|
import io
|
||||||
@@ -43,7 +44,6 @@ from electrum.lnutil import LnFeatures
|
|||||||
from electrum.util import OldTaskGroup, log_exceptions
|
from electrum.util import OldTaskGroup, log_exceptions
|
||||||
|
|
||||||
|
|
||||||
# do not use util.now, because it rounds to integers
|
|
||||||
def now():
|
def now():
|
||||||
return time.time()
|
return time.time()
|
||||||
|
|
||||||
@@ -411,12 +411,6 @@ def get_blinded_reply_paths(
|
|||||||
class Timeout(Exception): pass
|
class Timeout(Exception): pass
|
||||||
|
|
||||||
|
|
||||||
class OnionMessageRequest(NamedTuple):
|
|
||||||
future: asyncio.Future
|
|
||||||
payload: bytes
|
|
||||||
node_id_or_blinded_path: bytes
|
|
||||||
|
|
||||||
|
|
||||||
class OnionMessageManager(Logger):
|
class OnionMessageManager(Logger):
|
||||||
"""handle state around onion message sends and receives.
|
"""handle state around onion message sends and receives.
|
||||||
- one instance per (ln)wallet
|
- one instance per (ln)wallet
|
||||||
@@ -437,12 +431,17 @@ class OnionMessageManager(Logger):
|
|||||||
FORWARD_RETRY_DELAY = 2
|
FORWARD_RETRY_DELAY = 2
|
||||||
FORWARD_MAX_QUEUE = 3
|
FORWARD_MAX_QUEUE = 3
|
||||||
|
|
||||||
|
class Request(NamedTuple):
|
||||||
|
future: asyncio.Future
|
||||||
|
payload: dict
|
||||||
|
node_id_or_blinded_path: bytes
|
||||||
|
|
||||||
def __init__(self, lnwallet: 'LNWallet'):
|
def __init__(self, lnwallet: 'LNWallet'):
|
||||||
Logger.__init__(self)
|
Logger.__init__(self)
|
||||||
self.network = None # type: Optional['Network']
|
self.network = None # type: Optional['Network']
|
||||||
self.taskgroup = None # type: OldTaskGroup
|
self.taskgroup = None # type: OldTaskGroup
|
||||||
self.lnwallet = lnwallet
|
self.lnwallet = lnwallet
|
||||||
self.pending = {} # type: dict[bytes, OnionMessageRequest]
|
self.pending = {} # type: dict[bytes, OnionMessageManager.Request]
|
||||||
self.pending_lock = threading.Lock()
|
self.pending_lock = threading.Lock()
|
||||||
self.send_queue = asyncio.PriorityQueue()
|
self.send_queue = asyncio.PriorityQueue()
|
||||||
self.forward_queue = asyncio.PriorityQueue()
|
self.forward_queue = asyncio.PriorityQueue()
|
||||||
@@ -559,7 +558,7 @@ class OnionMessageManager(Logger):
|
|||||||
|
|
||||||
self.logger.debug(f'submit_send {key=} {payload=} {node_id_or_blinded_path=}')
|
self.logger.debug(f'submit_send {key=} {payload=} {node_id_or_blinded_path=}')
|
||||||
|
|
||||||
req = OnionMessageRequest(
|
req = OnionMessageManager.Request(
|
||||||
future=asyncio.Future(),
|
future=asyncio.Future(),
|
||||||
payload=payload,
|
payload=payload,
|
||||||
node_id_or_blinded_path=node_id_or_blinded_path
|
node_id_or_blinded_path=node_id_or_blinded_path
|
||||||
@@ -608,6 +607,19 @@ class OnionMessageManager(Logger):
|
|||||||
# TODO: use payload to determine prefix?
|
# TODO: use payload to determine prefix?
|
||||||
return b'electrum' + key
|
return b'electrum' + key
|
||||||
|
|
||||||
|
def _get_request_for_path_id(self, recipient_data: dict) -> Request:
|
||||||
|
path_id = recipient_data.get('path_id', {}).get('data')
|
||||||
|
if not path_id:
|
||||||
|
return None
|
||||||
|
if not path_id[:8] == b'electrum':
|
||||||
|
self.logger.warning('not a reply to our request (unknown path_id prefix)')
|
||||||
|
return None
|
||||||
|
key = path_id[8:]
|
||||||
|
req = self.pending.get(key)
|
||||||
|
if req is None:
|
||||||
|
self.logger.warning('not a reply to our request (unknown request)')
|
||||||
|
return req
|
||||||
|
|
||||||
def on_onion_message_received(self, recipient_data: dict, payload: dict) -> None:
|
def on_onion_message_received(self, recipient_data: dict, payload: dict) -> None:
|
||||||
# we are destination, sanity checks
|
# we are destination, sanity checks
|
||||||
# - if `encrypted_data_tlv` contains `allowed_features`:
|
# - if `encrypted_data_tlv` contains `allowed_features`:
|
||||||
@@ -622,25 +634,16 @@ class OnionMessageManager(Logger):
|
|||||||
# - if `path_id` is set and corresponds to a path the reader has previously published in a `reply_path`:
|
# - if `path_id` is set and corresponds to a path the reader has previously published in a `reply_path`:
|
||||||
# - if the onion message is not a reply to that previous onion:
|
# - if the onion message is not a reply to that previous onion:
|
||||||
# - MUST ignore the onion message
|
# - MUST ignore the onion message
|
||||||
# TODO: store path_id and lookup here
|
req = self._get_request_for_path_id(recipient_data)
|
||||||
if 'path_id' not in recipient_data:
|
if req is None:
|
||||||
# unsolicited onion_message
|
# unsolicited onion_message
|
||||||
self.on_onion_message_received_unsolicited(recipient_data, payload)
|
self.on_onion_message_received_unsolicited(recipient_data, payload)
|
||||||
else:
|
else:
|
||||||
self.on_onion_message_received_reply(recipient_data, payload)
|
self.on_onion_message_received_reply(req, recipient_data, payload)
|
||||||
|
|
||||||
def on_onion_message_received_reply(self, recipient_data: dict, payload: dict) -> None:
|
def on_onion_message_received_reply(self, request: Request, recipient_data: dict, payload: dict) -> None:
|
||||||
# check if this reply is associated with a known request
|
assert request is not None, 'Request is mandatory'
|
||||||
correl_data = recipient_data['path_id'].get('data')
|
request.future.set_result((recipient_data, payload))
|
||||||
if not correl_data[:8] == b'electrum':
|
|
||||||
self.logger.warning('not a reply to our request (unknown path_id prefix)')
|
|
||||||
return
|
|
||||||
key = correl_data[8:]
|
|
||||||
req = self.pending.get(key)
|
|
||||||
if req is None:
|
|
||||||
self.logger.warning('not a reply to our request (unknown request)')
|
|
||||||
return
|
|
||||||
req.future.set_result((recipient_data, payload))
|
|
||||||
|
|
||||||
def on_onion_message_received_unsolicited(self, recipient_data: dict, payload: dict) -> None:
|
def on_onion_message_received_unsolicited(self, recipient_data: dict, payload: dict) -> None:
|
||||||
self.logger.debug('unsolicited onion_message received')
|
self.logger.debug('unsolicited onion_message received')
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ from electrum.lnonion import (
|
|||||||
process_onion_packet, get_bolt04_onion_key, encrypt_onionmsg_data_tlv,
|
process_onion_packet, get_bolt04_onion_key, encrypt_onionmsg_data_tlv,
|
||||||
get_shared_secrets_along_route, new_onion_packet, ONION_MESSAGE_LARGE_SIZE,
|
get_shared_secrets_along_route, new_onion_packet, ONION_MESSAGE_LARGE_SIZE,
|
||||||
HOPS_DATA_SIZE, InvalidPayloadSize)
|
HOPS_DATA_SIZE, InvalidPayloadSize)
|
||||||
from electrum.crypto import get_ecdh
|
from electrum.crypto import get_ecdh, privkey_to_pubkey
|
||||||
from electrum.lnutil import LnFeatures
|
from electrum.lnutil import LnFeatures, Keypair
|
||||||
from electrum.onion_message import blinding_privkey, create_blinded_path, encrypt_onionmsg_tlv_hops_data, \
|
from electrum.onion_message import blinding_privkey, create_blinded_path, encrypt_onionmsg_tlv_hops_data, \
|
||||||
OnionMessageManager, NoRouteFound, Timeout
|
OnionMessageManager, NoRouteFound, Timeout
|
||||||
from electrum.util import bfh, read_json_file, OldTaskGroup, get_asyncio_loop
|
from electrum.util import bfh, read_json_file, OldTaskGroup, get_asyncio_loop
|
||||||
@@ -304,21 +304,21 @@ class TestOnionMessageManager(ElectrumTestCase):
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.alice = ECPrivkey(privkey_bytes=b'\x41'*32)
|
|
||||||
self.alice_pub = self.alice.get_public_key_bytes(compressed=True)
|
def keypair(privkey: ECPrivkey):
|
||||||
self.bob = ECPrivkey(privkey_bytes=b'\x42'*32)
|
priv = privkey.get_secret_bytes()
|
||||||
self.bob_pub = self.bob.get_public_key_bytes(compressed=True)
|
return Keypair(pubkey=privkey_to_pubkey(priv), privkey=priv)
|
||||||
self.carol = ECPrivkey(privkey_bytes=b'\x43'*32)
|
|
||||||
self.carol_pub = self.carol.get_public_key_bytes(compressed=True)
|
self.alice = keypair(ECPrivkey(privkey_bytes=b'\x41'*32))
|
||||||
self.dave = ECPrivkey(privkey_bytes=b'\x44'*32)
|
self.bob = keypair(ECPrivkey(privkey_bytes=b'\x42'*32))
|
||||||
self.dave_pub = self.dave.get_public_key_bytes(compressed=True)
|
self.carol = keypair(ECPrivkey(privkey_bytes=b'\x43'*32))
|
||||||
self.eve = ECPrivkey(privkey_bytes=b'\x45'*32)
|
self.dave = keypair(ECPrivkey(privkey_bytes=b'\x44'*32))
|
||||||
self.eve_pub = self.eve.get_public_key_bytes(compressed=True)
|
self.eve = keypair(ECPrivkey(privkey_bytes=b'\x45'*32))
|
||||||
|
|
||||||
async def run_test1(self, t):
|
async def run_test1(self, t):
|
||||||
t1 = t.submit_send(
|
t1 = t.submit_send(
|
||||||
payload={'message': {'text': 'alice_timeout'.encode('utf-8')}},
|
payload={'message': {'text': 'alice_timeout'.encode('utf-8')}},
|
||||||
node_id_or_blinded_path=self.alice_pub)
|
node_id_or_blinded_path=self.alice.pubkey)
|
||||||
|
|
||||||
with self.assertRaises(Timeout):
|
with self.assertRaises(Timeout):
|
||||||
await t1
|
await t1
|
||||||
@@ -326,7 +326,7 @@ class TestOnionMessageManager(ElectrumTestCase):
|
|||||||
async def run_test2(self, t):
|
async def run_test2(self, t):
|
||||||
t2 = t.submit_send(
|
t2 = t.submit_send(
|
||||||
payload={'message': {'text': 'bob_slow_timeout'.encode('utf-8')}},
|
payload={'message': {'text': 'bob_slow_timeout'.encode('utf-8')}},
|
||||||
node_id_or_blinded_path=self.bob_pub)
|
node_id_or_blinded_path=self.bob.pubkey)
|
||||||
|
|
||||||
with self.assertRaises(Timeout):
|
with self.assertRaises(Timeout):
|
||||||
await t2
|
await t2
|
||||||
@@ -334,7 +334,7 @@ class TestOnionMessageManager(ElectrumTestCase):
|
|||||||
async def run_test3(self, t, rkey):
|
async def run_test3(self, t, rkey):
|
||||||
t3 = t.submit_send(
|
t3 = t.submit_send(
|
||||||
payload={'message': {'text': 'carol_with_immediate_reply'.encode('utf-8')}},
|
payload={'message': {'text': 'carol_with_immediate_reply'.encode('utf-8')}},
|
||||||
node_id_or_blinded_path=self.carol_pub,
|
node_id_or_blinded_path=self.carol.pubkey,
|
||||||
key=rkey)
|
key=rkey)
|
||||||
|
|
||||||
t3_result = await t3
|
t3_result = await t3
|
||||||
@@ -343,7 +343,7 @@ class TestOnionMessageManager(ElectrumTestCase):
|
|||||||
async def run_test4(self, t, rkey):
|
async def run_test4(self, t, rkey):
|
||||||
t4 = t.submit_send(
|
t4 = t.submit_send(
|
||||||
payload={'message': {'text': 'dave_with_slow_reply'.encode('utf-8')}},
|
payload={'message': {'text': 'dave_with_slow_reply'.encode('utf-8')}},
|
||||||
node_id_or_blinded_path=self.dave_pub,
|
node_id_or_blinded_path=self.dave.pubkey,
|
||||||
key=rkey)
|
key=rkey)
|
||||||
|
|
||||||
t4_result = await t4
|
t4_result = await t4
|
||||||
@@ -352,34 +352,34 @@ class TestOnionMessageManager(ElectrumTestCase):
|
|||||||
async def run_test5(self, t):
|
async def run_test5(self, t):
|
||||||
t5 = t.submit_send(
|
t5 = t.submit_send(
|
||||||
payload={'message': {'text': 'no_peer'.encode('utf-8')}},
|
payload={'message': {'text': 'no_peer'.encode('utf-8')}},
|
||||||
node_id_or_blinded_path=self.eve_pub)
|
node_id_or_blinded_path=self.eve.pubkey)
|
||||||
|
|
||||||
with self.assertRaises(NoRouteFound):
|
with self.assertRaises(NoRouteFound):
|
||||||
await t5
|
await t5
|
||||||
|
|
||||||
async def test_manager(self):
|
async def test_request_and_reply(self):
|
||||||
n = MockNetwork()
|
n = MockNetwork()
|
||||||
k = keypair()
|
k = keypair()
|
||||||
q1, q2 = asyncio.Queue(), asyncio.Queue()
|
q1, q2 = asyncio.Queue(), asyncio.Queue()
|
||||||
lnw = MockLNWallet(local_keypair=k, chans=[], tx_queue=q1, name='test', has_anchors=False)
|
lnw = MockLNWallet(local_keypair=k, chans=[], tx_queue=q1, name='test_request_and_reply', has_anchors=False)
|
||||||
|
|
||||||
def slow(*args, **kwargs):
|
def slow(*args, **kwargs):
|
||||||
time.sleep(2*TIME_STEP)
|
time.sleep(2*TIME_STEP)
|
||||||
|
|
||||||
def withreply(key, *args, **kwargs):
|
def withreply(key, *args, **kwargs):
|
||||||
t.on_onion_message_received_reply({'path_id': {'data': b'electrum' + key}}, {})
|
t.on_onion_message_received({'path_id': {'data': b'electrum' + key}}, {})
|
||||||
|
|
||||||
def slowwithreply(key, *args, **kwargs):
|
def slowwithreply(key, *args, **kwargs):
|
||||||
time.sleep(2*TIME_STEP)
|
time.sleep(2*TIME_STEP)
|
||||||
t.on_onion_message_received_reply({'path_id': {'data': b'electrum' + key}}, {})
|
t.on_onion_message_received({'path_id': {'data': b'electrum' + key}}, {})
|
||||||
|
|
||||||
rkey1 = bfh('0102030405060708')
|
rkey1 = bfh('0102030405060708')
|
||||||
rkey2 = bfh('0102030405060709')
|
rkey2 = bfh('0102030405060709')
|
||||||
|
|
||||||
lnw.peers[self.alice_pub] = MockPeer(self.alice_pub)
|
lnw.peers[self.alice.pubkey] = MockPeer(self.alice.pubkey)
|
||||||
lnw.peers[self.bob_pub] = MockPeer(self.bob_pub, on_send_message=slow)
|
lnw.peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=slow)
|
||||||
lnw.peers[self.carol_pub] = MockPeer(self.carol_pub, on_send_message=partial(withreply, rkey1))
|
lnw.peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(withreply, rkey1))
|
||||||
lnw.peers[self.dave_pub] = MockPeer(self.dave_pub, on_send_message=partial(slowwithreply, rkey2))
|
lnw.peers[self.dave.pubkey] = MockPeer(self.dave.pubkey, on_send_message=partial(slowwithreply, rkey2))
|
||||||
t = OnionMessageManager(lnw)
|
t = OnionMessageManager(lnw)
|
||||||
t.start_network(network=n)
|
t.start_network(network=n)
|
||||||
|
|
||||||
@@ -404,3 +404,66 @@ class TestOnionMessageManager(ElectrumTestCase):
|
|||||||
self.logger.debug('stopping manager')
|
self.logger.debug('stopping manager')
|
||||||
await t.stop()
|
await t.stop()
|
||||||
await lnw.stop()
|
await lnw.stop()
|
||||||
|
|
||||||
|
async def test_forward(self):
|
||||||
|
n = MockNetwork()
|
||||||
|
q1 = asyncio.Queue()
|
||||||
|
lnw = MockLNWallet(local_keypair=self.alice, chans=[], tx_queue=q1, name='alice', has_anchors=False)
|
||||||
|
|
||||||
|
self.was_sent = False
|
||||||
|
|
||||||
|
def on_send(to: str, *args, **kwargs):
|
||||||
|
self.assertEqual(to, 'bob')
|
||||||
|
self.was_sent = True
|
||||||
|
# validate what's sent to bob
|
||||||
|
self.assertEqual(bfh(HOPS[1]['E']), kwargs['blinding'])
|
||||||
|
message_type, payload = decode_msg(bfh(test_vectors['decrypt']['hops'][1]['onion_message']))
|
||||||
|
self.assertEqual(message_type, 'onion_message')
|
||||||
|
self.assertEqual(payload['onion_message_packet'], kwargs['onion_message_packet'])
|
||||||
|
|
||||||
|
lnw.peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=partial(on_send, 'bob'))
|
||||||
|
lnw.peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(on_send, 'carol'))
|
||||||
|
t = OnionMessageManager(lnw)
|
||||||
|
t.start_network(network=n)
|
||||||
|
|
||||||
|
onionmsg = bfh(test_vectors['onionmessage']['onion_message_packet'])
|
||||||
|
try:
|
||||||
|
t.on_onion_message({
|
||||||
|
'blinding': bfh(test_vectors['route']['blinding']),
|
||||||
|
'len': len(onionmsg),
|
||||||
|
'onion_message_packet': onionmsg
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
await asyncio.sleep(TIME_STEP)
|
||||||
|
|
||||||
|
self.logger.debug('stopping manager')
|
||||||
|
await t.stop()
|
||||||
|
await lnw.stop()
|
||||||
|
|
||||||
|
self.assertTrue(self.was_sent)
|
||||||
|
|
||||||
|
async def test_receive_unsolicited(self):
|
||||||
|
n = MockNetwork()
|
||||||
|
q1 = asyncio.Queue()
|
||||||
|
lnw = MockLNWallet(local_keypair=self.dave, chans=[], tx_queue=q1, name='dave', has_anchors=False)
|
||||||
|
|
||||||
|
t = OnionMessageManager(lnw)
|
||||||
|
t.start_network(network=n)
|
||||||
|
|
||||||
|
self.received_unsolicited = False
|
||||||
|
|
||||||
|
def my_on_onion_message_received_unsolicited(*args, **kwargs):
|
||||||
|
self.received_unsolicited = True
|
||||||
|
|
||||||
|
t.on_onion_message_received_unsolicited = my_on_onion_message_received_unsolicited
|
||||||
|
packet = bfh(test_vectors['decrypt']['hops'][3]['onion_message'])
|
||||||
|
message_type, payload = decode_msg(packet)
|
||||||
|
try:
|
||||||
|
t.on_onion_message(payload)
|
||||||
|
self.assertTrue(self.received_unsolicited)
|
||||||
|
finally:
|
||||||
|
await asyncio.sleep(TIME_STEP)
|
||||||
|
|
||||||
|
self.logger.debug('stopping manager')
|
||||||
|
await t.stop()
|
||||||
|
await lnw.stop()
|
||||||
|
|||||||
Reference in New Issue
Block a user