1
0

lnworker: split LNWallet and LNWorker: LNWallet "has an" LNWorker

- LNWallet no longer "is-an" LNWorker, instead LNWallet "has-an" LNWorker
- the motivation is to make the unit tests nicer, and allow writing unit tests for more things
  - I hope this makes it possible to e.g. test lnsweep in the unit tests
  - some stuff we would previously have to write a regtest for, maybe we can write a unit test for, now
- in unit tests, MockLNWallet now
  - inherits LNWallet
  - the Wallet is no longer being mocked
This commit is contained in:
SomberNight
2025-12-17 15:16:05 +00:00
parent bdcd3f9c7c
commit 1006e8092f
17 changed files with 345 additions and 354 deletions

View File

@@ -760,22 +760,23 @@ class TestCommandsTestnet(ElectrumTestCase):
# Mock the network and lnworker
mock_lnworker = mock.Mock()
mock_lnworker.lnpeermgr = mock.Mock()
w.lnworker = mock_lnworker
mock_peer = mock.Mock()
mock_peer.initialized = asyncio.Future()
connection_string = "test_node_id@127.0.0.1:9735"
called = False
async def lnworker_add_peer(*args, **kwargs):
async def lnpeermgr_add_peer(*args, **kwargs):
assert args[0] == connection_string
nonlocal called
called += 1
return mock_peer
mock_lnworker.add_peer = lnworker_add_peer
mock_lnworker.lnpeermgr.add_peer = lnpeermgr_add_peer
# check if add_peer times out if peer doesn't initialize (LN_P2P_NETWORK_TIMEOUT is 0.001s)
with self.assertRaises(UserFacingException):
await cmds.add_peer(connection_string=connection_string, wallet=w)
# check if add_peer called lnworker.add_peer
# check if add_peer called lnpeermgr.add_peer
assert called == 1
mock_peer.initialized = asyncio.Future()

View File

@@ -23,6 +23,7 @@
# (around commit 42de4400bff5105352d0552155f73589166d162b).
import unittest
from functools import lru_cache
from unittest import mock
import os
import binascii
@@ -40,7 +41,7 @@ from electrum.crypto import privkey_to_pubkey
from electrum.lnutil import SENT, LOCAL, REMOTE, RECEIVED, UpdateAddHtlc
from electrum.lnutil import effective_htlc_tx_weight
from electrum.logging import console_stderr_handler
from electrum.lnchannel import ChannelState
from electrum.lnchannel import ChannelState, Channel
from electrum.json_db import StoredDict
from electrum.coinchooser import PRNG
@@ -124,6 +125,7 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
return StoredDict(state, None)
@lru_cache()
def bip32(sequence):
node = bip32_utils.BIP32Node.from_rootseed(b"9dk", xtype='standard').subkey_at_private_derivation(sequence)
k = node.eckey.get_secret_bytes()
@@ -137,7 +139,7 @@ def create_test_channels(*, feerate=6000, local_msat=None, remote_msat=None,
alice_pubkey=b"\x01"*33, bob_pubkey=b"\x02"*33, random_seed=None,
anchor_outputs=False,
local_max_inflight=None, remote_max_inflight=None,
max_accepted_htlcs=5):
max_accepted_htlcs=5) -> tuple[Channel, Channel]:
if random_seed is None: # needed for deterministic randomness
random_seed = os.urandom(32)
random_gen = PRNG(random_seed)

View File

@@ -10,6 +10,7 @@ from collections import defaultdict
import logging
import concurrent
from concurrent import futures
from functools import lru_cache
from unittest import mock
from typing import Iterable, NamedTuple, Tuple, List, Dict, Sequence
from types import MappingProxyType
@@ -24,6 +25,7 @@ import electrum.trampoline
from electrum import bitcoin
from electrum import util
from electrum import constants
from electrum import bip32
from electrum.network import Network
from electrum import simple_config, lnutil
from electrum.lnaddr import lnencode, LnAddr, lndecode
@@ -37,7 +39,7 @@ from electrum.lnutil import Keypair, PaymentFailure, LnFeatures, HTLCOwner, Paym
from electrum.lnchannel import ChannelState, PeerState, Channel
from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent
from electrum.channel_db import ChannelDB
from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo, PaySession
from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo, PaySession, LNPeerManager
from electrum.lnmsg import encode_msg, decode_msg
from electrum import lnmsg
from electrum.logging import console_stderr_handler, Logger
@@ -49,10 +51,11 @@ from electrum.interface import GracefulDisconnect
from electrum.simple_config import SimpleConfig
from electrum.fee_policy import FeeTimeEstimates, FEE_ETA_TARGETS
from electrum.mpp_split import split_amount_normal
from electrum.wallet import Abstract_Wallet
from .test_lnchannel import create_test_channels
from .test_bitcoin import needs_test_with_all_chacha20_implementations
from . import ElectrumTestCase
from . import ElectrumTestCase, restore_wallet_from_text__for_unittest
def keypair():
@@ -62,9 +65,6 @@ def keypair():
privkey=priv)
return k1
@contextmanager
def noop_lock():
yield
class MockNetwork:
def __init__(self, tx_queue, *, config: SimpleConfig):
@@ -120,144 +120,100 @@ class MockADB:
def get_local_height(self):
return self._blockchain.height()
class MockWallet:
receive_requests = {}
adb = MockADB()
def get_invoice(self, key):
pass
def get_request(self, key):
pass
def get_key_for_receive_request(self, x):
pass
def set_label(self, x, y):
pass
def save_db(self):
pass
def is_lightning_backup(self):
return False
def is_mine(self, addr):
return True
def get_fingerprint(self):
return ''
def get_new_sweep_address_for_channel(self):
# note: sweep is not tested here, only in regtest
return "tb1qqu5newtapamjchgxf0nty6geuykhvwas45q4q4"
def is_up_to_date(self):
return True
class MockLNGossip:
def get_sync_progress_estimate(self):
return None, None, None
class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
class MockLNPeerManager(LNPeerManager):
def __init__(
self,
*,
node_keypair,
config: SimpleConfig,
features: LnFeatures,
lnwallet: LNWallet,
network: 'MockNetwork',
):
LNPeerManager.__init__(
self,
node_keypair=node_keypair,
lnwallet_or_lngossip=lnwallet,
features=features,
config=config,
)
self.network = network
@lru_cache()
def _bip32_from_name(name: str) -> bip32.BIP32Node:
# note: unlike a serialized xprv, the bip32 node can be cached easily,
# as it does not depend on constant.net (testnet/mainnet) network bytes
sequence = [ord(c) for c in name]
bip32_node = bip32.BIP32Node.from_rootseed(b"9dk", xtype='standard').subkey_at_private_derivation(sequence)
return bip32_node
class MockLNWallet(LNWallet):
MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1
PAYMENT_TIMEOUT = 120
TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 0
MPP_SPLIT_PART_FRACTION = 1 # this disables the forced splitting
MPP_SPLIT_PART_MINAMT_MSAT = 5_000_000
def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name, has_anchors):
def __init__(self, *, tx_queue, name, has_anchors, ln_xprv: str = None):
self.name = name
Logger.__init__(self)
NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1)
self.node_keypair = local_keypair
self.payment_secret_key = os.urandom(32) # does not need to be deterministic in tests
self._user_dir = tempfile.mkdtemp(prefix="electrum-lnpeer-test-")
self.config = SimpleConfig({}, read_user_dir_function=lambda: self._user_dir)
self.network = MockNetwork(tx_queue, config=self.config)
self.taskgroup = OldTaskGroup()
self.config.ENABLE_ANCHOR_CHANNELS = has_anchors
self.config.INITIAL_TRAMPOLINE_FEE_LEVEL = 0
network = MockNetwork(tx_queue, config=self.config)
wallet = restore_wallet_from_text__for_unittest(
"9dk", path=None, passphrase=name, config=self.config)['wallet'] # type: Abstract_Wallet
wallet.is_up_to_date = lambda: True
wallet.adb.network = wallet.network = network
features = LnFeatures(0)
features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT
features |= LnFeatures.OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT
features |= LnFeatures.VAR_ONION_OPT
features |= LnFeatures.PAYMENT_SECRET_OPT
features |= LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
features |= LnFeatures.OPTION_CHANNEL_TYPE_OPT
features |= LnFeatures.OPTION_SCID_ALIAS_OPT
features |= LnFeatures.OPTION_STATIC_REMOTEKEY_OPT
if ln_xprv is None:
ln_xprv = _bip32_from_name(name).to_xprv()
LNWallet.__init__(self, wallet=wallet, xprv=ln_xprv, features=features)
self.lnpeermgr = MockLNPeerManager(
node_keypair=self.node_keypair,
config=self.config,
features=features,
lnwallet=self,
network=network,
)
self.lnwatcher = None
self.swap_manager = None
self.onion_message_manager = None
self.listen_server = None
self._channels = {chan.channel_id: chan for chan in chans}
self.payment_info = {}
self.logs = defaultdict(list)
self.wallet = MockWallet()
self.features = LnFeatures(0)
self.features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT
self.features |= LnFeatures.OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT
self.features |= LnFeatures.VAR_ONION_OPT
self.features |= LnFeatures.PAYMENT_SECRET_OPT
self.features |= LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
self.features |= LnFeatures.OPTION_CHANNEL_TYPE_OPT
self.features |= LnFeatures.OPTION_SCID_ALIAS_OPT
self.features |= LnFeatures.OPTION_STATIC_REMOTEKEY_OPT
self.config.ENABLE_ANCHOR_CHANNELS = has_anchors
for chan in chans:
chan.lnworker = self
self._peers = {} # bytes -> Peer
# used in tests
self.enable_htlc_settle = True
self.enable_htlc_forwarding = True
self.received_mpp_htlcs = dict()
self._paysessions = dict()
self.sent_htlcs_info = dict()
self.sent_buckets = defaultdict(set)
self.active_forwardings = {}
self.forwarding_failures = {}
self.inflight_payments = set()
self._preimages = {}
self.stopping_soon = False
self.downstream_to_upstream_htlc = {}
self.dont_expire_htlcs = {}
self.dont_settle_htlcs = {}
self.hold_invoice_callbacks = {}
self._payment_bundles_pkey_to_canon = {} # type: Dict[bytes, bytes]
self._payment_bundles_canon_to_pkeylist = {} # type: Dict[bytes, Sequence[bytes]]
self.config.INITIAL_TRAMPOLINE_FEE_LEVEL = 0
self._channel_sending_capacity_lock = asyncio.Lock()
self.logger.info(f"created LNWallet[{name}] with nodeID={local_keypair.pubkey.hex()}")
self.logger.info(f"created LNWallet[{name}] with nodeID={self.node_keypair.pubkey.hex()}")
def clear_invoices_cache(self):
pass
def _add_channel(self, chan: Channel):
self._channels[chan.channel_id] = chan
chan.lnworker = self
def get_invoice_status(self, key):
pass
@property
def lock(self):
return noop_lock()
@property
def channel_db(self):
return self.network.channel_db if self.network else None
def uses_trampoline(self):
return not bool(self.channel_db)
@property
def channels(self):
return self._channels
@property
def peers(self):
return self._peers
def get_channel_by_short_id(self, short_channel_id):
with self.lock:
for chan in self._channels.values():
if chan.short_channel_id == short_channel_id:
return chan
def channel_state_changed(self, chan):
pass
@LNWallet.features.setter
def features(self, value):
self.lnpeermgr.features = value
def save_channel(self, chan):
print("Ignoring channel save")
pass
#print("Ignoring channel save")
def diagnostic_name(self):
return self.name
@@ -290,69 +246,6 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
budget=PaymentFeeBudget.from_invoice_amount(invoice_amount_msat=amount_msat, config=self.config),
)]
get_payments = LNWallet.get_payments
get_payment_secret = LNWallet.get_payment_secret
get_payment_info = LNWallet.get_payment_info
save_payment_info = LNWallet.save_payment_info
set_invoice_status = LNWallet.set_invoice_status
set_request_status = LNWallet.set_request_status
set_payment_status = LNWallet.set_payment_status
get_payment_status = LNWallet.get_payment_status
htlc_fulfilled = LNWallet.htlc_fulfilled
htlc_failed = LNWallet.htlc_failed
save_preimage = LNWallet.save_preimage
get_preimage = LNWallet.get_preimage
create_route_for_single_htlc = LNWallet.create_route_for_single_htlc
create_routes_for_payment = LNWallet.create_routes_for_payment
_check_bolt11_invoice = LNWallet._check_bolt11_invoice
pay_to_route = LNWallet.pay_to_route
pay_to_node = LNWallet.pay_to_node
pay_invoice = LNWallet.pay_invoice
force_close_channel = LNWallet.force_close_channel
schedule_force_closing = LNWallet.schedule_force_closing
on_peer_successfully_established = LNWallet.on_peer_successfully_established
get_channel_by_id = LNWallet.get_channel_by_id
channels_for_peer = LNWallet.channels_for_peer
calc_routing_hints_for_invoice = LNWallet.calc_routing_hints_for_invoice
get_channels_for_receiving = LNWallet.get_channels_for_receiving
handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc
is_trampoline_peer = LNWallet.is_trampoline_peer
wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed
#on_event_proxy_set = LNWallet.on_event_proxy_set
_decode_channel_update_msg = LNWallet._decode_channel_update_msg
_handle_chanupd_from_failed_htlc = LNWallet._handle_chanupd_from_failed_htlc
is_forwarded_htlc = LNWallet.is_forwarded_htlc
notify_upstream_peer = LNWallet.notify_upstream_peer
_force_close_channel = LNWallet._force_close_channel
suggest_payment_splits = LNWallet.suggest_payment_splits
register_hold_invoice = LNWallet.register_hold_invoice
unregister_hold_invoice = LNWallet.unregister_hold_invoice
add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice
update_or_create_mpp_with_received_htlc = LNWallet.update_or_create_mpp_with_received_htlc
set_mpp_resolution = LNWallet.set_mpp_resolution
get_mpp_amounts = LNWallet.get_mpp_amounts
bundle_payments = LNWallet.bundle_payments
get_payment_bundle = LNWallet.get_payment_bundle
_get_payment_key = LNWallet._get_payment_key
save_forwarding_failure = LNWallet.save_forwarding_failure
get_forwarding_failure = LNWallet.get_forwarding_failure
maybe_cleanup_forwarding = LNWallet.maybe_cleanup_forwarding
current_target_feerate_per_kw = LNWallet.current_target_feerate_per_kw
current_low_feerate_per_kw_srk_channel = LNWallet.current_low_feerate_per_kw_srk_channel
create_onion_for_route = LNWallet.create_onion_for_route
maybe_forward_htlc_set = LNWallet.maybe_forward_htlc_set
_maybe_forward_htlc = LNWallet._maybe_forward_htlc
_maybe_forward_trampoline = LNWallet._maybe_forward_trampoline
_maybe_refuse_to_forward_htlc_that_corresponds_to_payreq_we_created = LNWallet._maybe_refuse_to_forward_htlc_that_corresponds_to_payreq_we_created
set_htlc_set_error = LNWallet.set_htlc_set_error
is_payment_bundle_complete = LNWallet.is_payment_bundle_complete
delete_payment_bundle = LNWallet.delete_payment_bundle
_process_htlc_log = LNWallet._process_htlc_log
_get_invoice_features = LNWallet._get_invoice_features
receive_requires_jit_channel = LNWallet.receive_requires_jit_channel
can_get_zeroconf_channel = LNWallet.can_get_zeroconf_channel
class MockTransport:
def __init__(self, name):
@@ -667,25 +560,24 @@ class TestPeerDirect(TestPeer):
def prepare_peers(
self, alice_channel: Channel, bob_channel: Channel,
*, k1: Keypair = None, k2: Keypair = None,
):
if k1 is None:
k1 = keypair()
if k2 is None:
k2 = keypair()
q1, q2 = asyncio.Queue(), asyncio.Queue()
w1 = MockLNWallet(tx_queue=q1, name=bob_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS)
w2 = MockLNWallet(tx_queue=q2, name=alice_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS)
k1 = w1.node_keypair
k2 = w2.node_keypair
alice_channel.node_id = k2.pubkey
bob_channel.node_id = k1.pubkey
alice_channel.storage['node_id'] = alice_channel.node_id
bob_channel.storage['node_id'] = bob_channel.node_id
t1, t2 = transport_pair(k1, k2, alice_channel.name, bob_channel.name)
q1, q2 = asyncio.Queue(), asyncio.Queue()
w1 = MockLNWallet(local_keypair=k1, chans=[alice_channel], tx_queue=q1, name=bob_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS)
w2 = MockLNWallet(local_keypair=k2, chans=[bob_channel], tx_queue=q2, name=alice_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS)
w1._add_channel(alice_channel)
w2._add_channel(bob_channel)
self._lnworkers_created.extend([w1, w2])
p1 = PeerInTests(w1, k2.pubkey, t1)
p2 = PeerInTests(w2, k1.pubkey, t2)
w1._peers[p1.pubkey] = p1
w2._peers[p2.pubkey] = p2
w1.lnpeermgr._peers[p1.pubkey] = p1
w2.lnpeermgr._peers[p2.pubkey] = p2
# mark_open won't work if state is already OPEN.
# so set it to FUNDED
alice_channel._state = ChannelState.FUNDED
@@ -790,10 +682,9 @@ class TestPeerDirect(TestPeer):
----sig-->
"""
chan_AB, chan_BA = create_test_channels()
k1, k2 = keypair(), keypair()
# note: we don't start peer.htlc_switch() so that the fake htlcs are left alone.
async def f():
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2)
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA)
async with OldTaskGroup() as group:
await group.spawn(p1._message_loop())
await group.spawn(p2._message_loop())
@@ -807,7 +698,7 @@ class TestPeerDirect(TestPeer):
await group.cancel_remaining()
# simulating disconnection. recreate transports.
self.logger.info("simulating disconnection. recreating transports.")
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2)
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA)
for chan in (chan_AB, chan_BA):
chan.peer_state = PeerState.DISCONNECTED
async with OldTaskGroup() as group:
@@ -846,10 +737,9 @@ class TestPeerDirect(TestPeer):
----rev-->
"""
chan_AB, chan_BA = create_test_channels()
k1, k2 = keypair(), keypair()
# note: we don't start peer.htlc_switch() so that the fake htlcs are left alone.
async def f():
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2)
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA)
async with OldTaskGroup() as group:
await group.spawn(p1._message_loop())
await group.spawn(p2._message_loop())
@@ -864,7 +754,7 @@ class TestPeerDirect(TestPeer):
await group.cancel_remaining()
# simulating disconnection. recreate transports.
self.logger.info("simulating disconnection. recreating transports.")
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2)
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA)
for chan in (chan_AB, chan_BA):
chan.peer_state = PeerState.DISCONNECTED
async with OldTaskGroup() as group:
@@ -1788,7 +1678,7 @@ class TestPeerDirect(TestPeer):
with self.assertRaises(NoPathFound) as e:
await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)
peer = w1.peers[route[0].node_id]
peer = w1.lnpeermgr._peers[route[0].node_id]
# AssertionError is ok since we shouldn't use old routes, and the
# route finding should fail when channel is closed
async def f():
@@ -2126,12 +2016,19 @@ class TestPeerDirect(TestPeer):
class TestPeerForwarding(TestPeer):
def prepare_chans_and_peers_in_graph(self, graph_definition) -> Graph:
keys = {k: keypair() for k in graph_definition}
workers = {} # type: Dict[str, MockLNWallet]
txs_queues = {k: asyncio.Queue() for k in graph_definition}
# create workers
for a, definition in graph_definition.items():
workers[a] = MockLNWallet(tx_queue=txs_queues[a], name=a, has_anchors=self.TEST_ANCHOR_CHANNELS)
self._lnworkers_created.extend(list(workers.values()))
keys = {name: w.node_keypair for name, w in workers.items()}
channels = {} # type: Dict[Tuple[str, str], Channel]
transports = {}
workers = {} # type: Dict[str, MockLNWallet]
peers = {}
# create channels
for a, definition in graph_definition.items():
for b, channel_def in definition.get('channels', {}).items():
@@ -2145,6 +2042,8 @@ class TestPeerForwarding(TestPeer):
anchor_outputs=self.TEST_ANCHOR_CHANNELS
)
channels[(a, b)], channels[(b, a)] = channel_ab, channel_ba
workers[a]._add_channel(channel_ab)
workers[b]._add_channel(channel_ba)
transport_ab, transport_ba = transport_pair(keys[a], keys[b], channel_ab.name, channel_ba.name)
transports[(a, b)], transports[(b, a)] = transport_ab, transport_ba
# set fees
@@ -2153,12 +2052,6 @@ class TestPeerForwarding(TestPeer):
channel_ba.forwarding_fee_proportional_millionths = channel_def['remote_fee_rate_millionths']
channel_ba.forwarding_fee_base_msat = channel_def['remote_base_fee_msat']
# create workers and peers
for a, definition in graph_definition.items():
channels_of_node = [c for k, c in channels.items() if k[0] == a]
workers[a] = MockLNWallet(local_keypair=keys[a], chans=channels_of_node, tx_queue=txs_queues[a], name=a, has_anchors=self.TEST_ANCHOR_CHANNELS)
self._lnworkers_created.extend(list(workers.values()))
# create peers
for ab in channels.keys():
peers[ab] = Peer(workers[ab[0]], keys[ab[1]].pubkey, transports[ab])
@@ -2167,7 +2060,7 @@ class TestPeerForwarding(TestPeer):
for a, w in workers.items():
for ab, peer_ab in peers.items():
if ab[0] == a:
w._peers[peer_ab.pubkey] = peer_ab
w.lnpeermgr._peers[peer_ab.pubkey] = peer_ab
# set forwarding properties
for a, definition in graph_definition.items():

View File

@@ -352,9 +352,8 @@ class TestOnionMessageManager(ElectrumTestCase):
async def test_request_and_reply(self):
n = MockNetwork()
k = keypair()
q1, q2 = asyncio.Queue(), asyncio.Queue()
lnw = MockLNWallet(local_keypair=k, chans=[], tx_queue=q1, name='test_request_and_reply', has_anchors=False)
lnw = MockLNWallet(tx_queue=q1, name='test_request_and_reply', has_anchors=False)
def slow(*args, **kwargs):
time.sleep(2*TIME_STEP)
@@ -369,10 +368,10 @@ class TestOnionMessageManager(ElectrumTestCase):
rkey1 = bfh('0102030405060708')
rkey2 = bfh('0102030405060709')
lnw.peers[self.alice.pubkey] = MockPeer(self.alice.pubkey)
lnw.peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=slow)
lnw.peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(withreply, rkey1))
lnw.peers[self.dave.pubkey] = MockPeer(self.dave.pubkey, on_send_message=partial(slowwithreply, rkey2))
lnw.lnpeermgr._peers[self.alice.pubkey] = MockPeer(self.alice.pubkey)
lnw.lnpeermgr._peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=slow)
lnw.lnpeermgr._peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(withreply, rkey1))
lnw.lnpeermgr._peers[self.dave.pubkey] = MockPeer(self.dave.pubkey, on_send_message=partial(slowwithreply, rkey2))
t = OnionMessageManager(lnw)
t.start_network(network=n)
@@ -401,7 +400,8 @@ class TestOnionMessageManager(ElectrumTestCase):
async def test_forward(self):
n = MockNetwork()
q1 = asyncio.Queue()
lnw = MockLNWallet(local_keypair=self.alice, chans=[], tx_queue=q1, name='alice', has_anchors=False)
lnw = MockLNWallet(tx_queue=q1, name='alice', has_anchors=False)
lnw.node_keypair = self.alice
self.was_sent = False
@@ -414,8 +414,8 @@ class TestOnionMessageManager(ElectrumTestCase):
self.assertEqual(message_type, 'onion_message')
self.assertEqual(payload['onion_message_packet'], kwargs['onion_message_packet'])
lnw.peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=partial(on_send, 'bob'))
lnw.peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(on_send, 'carol'))
lnw.lnpeermgr._peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=partial(on_send, 'bob'))
lnw.lnpeermgr._peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(on_send, 'carol'))
t = OnionMessageManager(lnw)
t.start_network(network=n)
@@ -438,7 +438,8 @@ class TestOnionMessageManager(ElectrumTestCase):
async def test_receive_unsolicited(self):
n = MockNetwork()
q1 = asyncio.Queue()
lnw = MockLNWallet(local_keypair=self.dave, chans=[], tx_queue=q1, name='dave', has_anchors=False)
lnw = MockLNWallet(tx_queue=q1, name='dave', has_anchors=False)
lnw.node_keypair = self.dave
t = OnionMessageManager(lnw)
t.start_network(network=n)