1
0

tests: lnpeer: fix cyclic lnworker.wallet.lnworker inconsistency

These better hold, lol:
wallet.lnworker.wallet == wallet
lnworker.wallet.lnworker == lnworker
This commit is contained in:
SomberNight
2025-12-19 14:36:22 +00:00
parent ea42b02ceb
commit 17f41044d5
4 changed files with 64 additions and 81 deletions

View File

@@ -4262,7 +4262,7 @@ class Wallet(object):
def __new__(cls, db: 'WalletDB', *, config: SimpleConfig) -> Abstract_Wallet:
wallet_type = db.get('wallet_type')
WalletClass = Wallet.wallet_class(wallet_type)
WalletClass = cls.wallet_class(wallet_type)
wallet = WalletClass(db, config=config)
return wallet
@@ -4320,6 +4320,7 @@ def restore_wallet_from_text(
encrypt_file: Optional[bool] = None,
gap_limit: Optional[int] = None,
gap_limit_for_change: Optional[int] = None,
wallet_factory = Wallet, # used in tests
) -> dict:
"""Restore a wallet from text. Text can be a seed phrase, a master
public key, a master private key, a list of bitcoin addresses
@@ -4365,7 +4366,7 @@ def restore_wallet_from_text(
db.put('gap_limit', gap_limit)
if gap_limit_for_change is not None:
db.put('gap_limit_for_change', gap_limit_for_change)
wallet = Wallet(db, config=config)
wallet = wallet_factory(db, config=config)
if db.storage:
assert not db.storage.file_exists(), "file was created too soon! plaintext keys might have been written to disk"
wallet.update_password(old_pw=None, new_pw=password, encrypt_storage=encrypt_file)

View File

@@ -175,12 +175,12 @@ def create_test_channels(
remote_max_inflight = funding_sat * 1000 if remote_max_inflight is None else remote_max_inflight
alice_raw = [bip32("m/" + str(i)) for i in range(5)]
bob_raw = [bip32("m/" + str(i)) for i in range(5,11)]
alice_privkeys = [lnutil.Keypair(privkey_to_pubkey(x), x) for x in alice_raw]
alice_privkeys = [lnutil.Keypair(privkey_to_pubkey(x), x) for x in alice_raw] # TODO make it depend on alice_lnwallet
bob_privkeys = [lnutil.Keypair(privkey_to_pubkey(x), x) for x in bob_raw]
alice_pubkeys = [lnutil.OnlyPubkeyKeypair(x.pubkey) for x in alice_privkeys]
bob_pubkeys = [lnutil.OnlyPubkeyKeypair(x.pubkey) for x in bob_privkeys]
alice_seed = random_gen.get_bytes(32)
alice_seed = random_gen.get_bytes(32) # TODO make it depend on alice_lnwallet
bob_seed = random_gen.get_bytes(32)
alice_first = lnutil.secret_to_pubkey(
@@ -201,7 +201,9 @@ def create_test_channels(
max_accepted_htlcs=max_accepted_htlcs,
),
name=f"{alice_name}->{bob_name}",
initial_feerate=feerate),
initial_feerate=feerate,
lnworker=alice_lnwallet,
),
lnchannel.Channel(
create_channel_state(
funding_txid, funding_index, funding_sat, False, remote_amount,
@@ -212,7 +214,9 @@ def create_test_channels(
max_accepted_htlcs=max_accepted_htlcs,
),
name=f"{bob_name}->{alice_name}",
initial_feerate=feerate)
initial_feerate=feerate,
lnworker=bob_lnwallet,
)
)
alice.hm.log[LOCAL]['ctn'] = 0

View File

@@ -51,7 +51,7 @@ from electrum.interface import GracefulDisconnect
from electrum.simple_config import SimpleConfig
from electrum.fee_policy import FeeTimeEstimates, FEE_ETA_TARGETS
from electrum.mpp_split import split_amount_normal
from electrum.wallet import Abstract_Wallet
from electrum.wallet import Abstract_Wallet, Standard_Wallet
from .test_lnchannel import create_test_channels
from .test_bitcoin import needs_test_with_all_chacha20_implementations
@@ -116,93 +116,70 @@ class MockLNGossip:
return None, None, None
class MockLNPeerManager(LNPeerManager):
def __init__(
self,
*,
node_keypair,
config: SimpleConfig,
features: LnFeatures,
lnwallet: LNWallet,
network: 'MockNetwork',
):
LNPeerManager.__init__(
self,
node_keypair=node_keypair,
lnwallet_or_lngossip=lnwallet,
features=features,
config=config,
)
self.network = network
class MockWalletFactory(electrum.wallet.Wallet):
@staticmethod
def wallet_class(wallet_type):
real_wallet_class = electrum.wallet.Wallet.wallet_class(wallet_type)
if real_wallet_class is Standard_Wallet:
return MockStandardWallet
return real_wallet_class
@lru_cache()
def _bip32_from_name(name: str) -> bip32.BIP32Node:
# note: unlike a serialized xprv, the bip32 node can be cached easily,
# as it does not depend on constant.net (testnet/mainnet) network bytes
sequence = [ord(c) for c in name]
bip32_node = bip32.BIP32Node.from_rootseed(b"9dk", xtype='standard').subkey_at_private_derivation(sequence)
return bip32_node
class MockStandardWallet(Standard_Wallet):
def _init_lnworker(self):
ln_xprv = self.db.get('lightning_xprv') or self.db.get('lightning_privkey2')
assert ln_xprv
self.lnworker = MockLNWallet(self, ln_xprv)
def basename(self):
passphrase = self.db.get("keystore").get("passphrase")
assert passphrase
return passphrase # lol, super secure name
def create_mock_lnwallet(*, name, has_anchors) -> 'MockLNWallet':
_user_dir = tempfile.mkdtemp(prefix="electrum-lnpeer-test-")
config = SimpleConfig({}, read_user_dir_function=lambda: _user_dir)
config.ENABLE_ANCHOR_CHANNELS = has_anchors
config.INITIAL_TRAMPOLINE_FEE_LEVEL = 0
network = MockNetwork(config=config)
wallet = restore_wallet_from_text__for_unittest(
"9dk", path=None, passphrase=name, config=config,
wallet_factory=MockWalletFactory,
)['wallet'] # type: MockStandardWallet
wallet.is_up_to_date = lambda: True
wallet.adb.network = wallet.network = network
lnworker = wallet.lnworker
assert isinstance(lnworker, MockLNWallet), f"{lnworker=!r}"
lnworker._user_dir = _user_dir
lnworker.lnpeermgr.network = network
lnworker.logger.info(f"created LNWallet[{name}] with nodeID={lnworker.node_keypair.pubkey.hex()}")
return lnworker
class MockLNWallet(LNWallet):
MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1
TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 0
MPP_SPLIT_PART_FRACTION = 1 # this disables the forced splitting
def __init__(self, *, name, has_anchors):
self.name = name
self._user_dir = tempfile.mkdtemp(prefix="electrum-lnpeer-test-")
self.config = SimpleConfig({}, read_user_dir_function=lambda: self._user_dir)
self.config.ENABLE_ANCHOR_CHANNELS = has_anchors
self.config.INITIAL_TRAMPOLINE_FEE_LEVEL = 0
network = MockNetwork(config=self.config)
wallet = restore_wallet_from_text__for_unittest(
"9dk", path=None, passphrase=name, config=self.config)['wallet'] # type: Abstract_Wallet
wallet.is_up_to_date = lambda: True
wallet.adb.network = wallet.network = network
#assert wallet.lnworker is None # FIXME xxxxx wallet already has another lnworker by now >.<
features = LnFeatures(0)
features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT
features |= LnFeatures.OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT
features |= LnFeatures.VAR_ONION_OPT
features |= LnFeatures.PAYMENT_SECRET_OPT
features |= LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
features |= LnFeatures.OPTION_CHANNEL_TYPE_OPT
features |= LnFeatures.OPTION_SCID_ALIAS_OPT
features |= LnFeatures.OPTION_STATIC_REMOTEKEY_OPT
ln_xprv = _bip32_from_name(name).to_xprv()
LNWallet.__init__(self, wallet=wallet, xprv=ln_xprv, features=features)
self.lnpeermgr = MockLNPeerManager(
node_keypair=self.node_keypair,
config=self.config,
features=features,
lnwallet=self,
network=network,
)
self.logger.info(f"created LNWallet[{name}] with nodeID={self.node_keypair.pubkey.hex()}")
def __init__(self, *args, **kwargs):
LNWallet.__init__(self, *args, **kwargs)
self.features &= ~LnFeatures.BASIC_MPP_OPT # by default, disable MPP
def _add_channel(self, chan: Channel):
self._channels[chan.channel_id] = chan
# assert chan.lnworker == self # this fails as some tests are reusing chans in a weird way
chan.lnworker = self
@LNWallet.features.setter
def features(self, value):
self.lnpeermgr.features = value
def save_channel(self, chan):
pass
#print("Ignoring channel save")
def diagnostic_name(self):
return self.name
@property
def name(self):
return self.wallet.basename()
async def stop(self):
await LNWallet.stop(self)
@@ -524,7 +501,7 @@ class TestPeer(ElectrumTestCase):
def prepare_lnwallets(self, graph_definition) -> Mapping[str, MockLNWallet]:
workers = {} # type: Dict[str, MockLNWallet]
for a, definition in graph_definition.items():
workers[a] = MockLNWallet(name=a, has_anchors=self.TEST_ANCHOR_CHANNELS)
workers[a] = create_mock_lnwallet(name=a, has_anchors=self.TEST_ANCHOR_CHANNELS)
self._lnworkers_created.extend(list(workers.values()))
return workers

View File

@@ -25,7 +25,8 @@ from electrum.util import bfh, read_json_file, OldTaskGroup, get_asyncio_loop
from electrum.logging import console_stderr_handler
from . import ElectrumTestCase
from .test_lnpeer import keypair, MockLNWallet
from .test_lnpeer import create_mock_lnwallet
TIME_STEP = 0.01 # run tests 100 x faster
OnionMessageManager.SLEEP_DELAY *= TIME_STEP
@@ -352,7 +353,7 @@ class TestOnionMessageManager(ElectrumTestCase):
async def test_request_and_reply(self):
n = MockNetwork()
lnw = MockLNWallet(name='test_request_and_reply', has_anchors=False)
lnw = create_mock_lnwallet(name='test_request_and_reply', has_anchors=False)
def slow(*args, **kwargs):
time.sleep(2*TIME_STEP)
@@ -398,7 +399,7 @@ class TestOnionMessageManager(ElectrumTestCase):
async def test_forward(self):
n = MockNetwork()
lnw = MockLNWallet(name='alice', has_anchors=False)
lnw = create_mock_lnwallet(name='alice', has_anchors=False)
lnw.node_keypair = self.alice
self.was_sent = False
@@ -435,7 +436,7 @@ class TestOnionMessageManager(ElectrumTestCase):
async def test_receive_unsolicited(self):
n = MockNetwork()
lnw = MockLNWallet(name='dave', has_anchors=False)
lnw = create_mock_lnwallet(name='dave', has_anchors=False)
lnw.node_keypair = self.dave
t = OnionMessageManager(lnw)