diff --git a/electrum/wallet.py b/electrum/wallet.py index 5071325f1..cf7874ff0 100644 --- a/electrum/wallet.py +++ b/electrum/wallet.py @@ -4262,7 +4262,7 @@ class Wallet(object): def __new__(cls, db: 'WalletDB', *, config: SimpleConfig) -> Abstract_Wallet: wallet_type = db.get('wallet_type') - WalletClass = Wallet.wallet_class(wallet_type) + WalletClass = cls.wallet_class(wallet_type) wallet = WalletClass(db, config=config) return wallet @@ -4320,6 +4320,7 @@ def restore_wallet_from_text( encrypt_file: Optional[bool] = None, gap_limit: Optional[int] = None, gap_limit_for_change: Optional[int] = None, + wallet_factory = Wallet, # used in tests ) -> dict: """Restore a wallet from text. Text can be a seed phrase, a master public key, a master private key, a list of bitcoin addresses @@ -4365,7 +4366,7 @@ def restore_wallet_from_text( db.put('gap_limit', gap_limit) if gap_limit_for_change is not None: db.put('gap_limit_for_change', gap_limit_for_change) - wallet = Wallet(db, config=config) + wallet = wallet_factory(db, config=config) if db.storage: assert not db.storage.file_exists(), "file was created too soon! plaintext keys might have been written to disk" wallet.update_password(old_pw=None, new_pw=password, encrypt_storage=encrypt_file) diff --git a/tests/test_lnchannel.py b/tests/test_lnchannel.py index 8aebab7c6..0b3683345 100644 --- a/tests/test_lnchannel.py +++ b/tests/test_lnchannel.py @@ -175,12 +175,12 @@ def create_test_channels( remote_max_inflight = funding_sat * 1000 if remote_max_inflight is None else remote_max_inflight alice_raw = [bip32("m/" + str(i)) for i in range(5)] bob_raw = [bip32("m/" + str(i)) for i in range(5,11)] - alice_privkeys = [lnutil.Keypair(privkey_to_pubkey(x), x) for x in alice_raw] + alice_privkeys = [lnutil.Keypair(privkey_to_pubkey(x), x) for x in alice_raw] # TODO make it depend on alice_lnwallet bob_privkeys = [lnutil.Keypair(privkey_to_pubkey(x), x) for x in bob_raw] alice_pubkeys = [lnutil.OnlyPubkeyKeypair(x.pubkey) for x in alice_privkeys] bob_pubkeys = [lnutil.OnlyPubkeyKeypair(x.pubkey) for x in bob_privkeys] - alice_seed = random_gen.get_bytes(32) + alice_seed = random_gen.get_bytes(32) # TODO make it depend on alice_lnwallet bob_seed = random_gen.get_bytes(32) alice_first = lnutil.secret_to_pubkey( @@ -201,7 +201,9 @@ def create_test_channels( max_accepted_htlcs=max_accepted_htlcs, ), name=f"{alice_name}->{bob_name}", - initial_feerate=feerate), + initial_feerate=feerate, + lnworker=alice_lnwallet, + ), lnchannel.Channel( create_channel_state( funding_txid, funding_index, funding_sat, False, remote_amount, @@ -212,7 +214,9 @@ def create_test_channels( max_accepted_htlcs=max_accepted_htlcs, ), name=f"{bob_name}->{alice_name}", - initial_feerate=feerate) + initial_feerate=feerate, + lnworker=bob_lnwallet, + ) ) alice.hm.log[LOCAL]['ctn'] = 0 diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 1e1c64143..134c6a803 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -51,7 +51,7 @@ from electrum.interface import GracefulDisconnect from electrum.simple_config import SimpleConfig from electrum.fee_policy import FeeTimeEstimates, FEE_ETA_TARGETS from electrum.mpp_split import split_amount_normal -from electrum.wallet import Abstract_Wallet +from electrum.wallet import Abstract_Wallet, Standard_Wallet from .test_lnchannel import create_test_channels from .test_bitcoin import needs_test_with_all_chacha20_implementations @@ -116,93 +116,70 @@ class MockLNGossip: return None, None, None -class MockLNPeerManager(LNPeerManager): - def __init__( - self, - *, - node_keypair, - config: SimpleConfig, - features: LnFeatures, - lnwallet: LNWallet, - network: 'MockNetwork', - ): - LNPeerManager.__init__( - self, - node_keypair=node_keypair, - lnwallet_or_lngossip=lnwallet, - features=features, - config=config, - ) - self.network = network +class MockWalletFactory(electrum.wallet.Wallet): + + @staticmethod + def wallet_class(wallet_type): + real_wallet_class = electrum.wallet.Wallet.wallet_class(wallet_type) + if real_wallet_class is Standard_Wallet: + return MockStandardWallet + return real_wallet_class -@lru_cache() -def _bip32_from_name(name: str) -> bip32.BIP32Node: - # note: unlike a serialized xprv, the bip32 node can be cached easily, - # as it does not depend on constant.net (testnet/mainnet) network bytes - sequence = [ord(c) for c in name] - bip32_node = bip32.BIP32Node.from_rootseed(b"9dk", xtype='standard').subkey_at_private_derivation(sequence) - return bip32_node +class MockStandardWallet(Standard_Wallet): + def _init_lnworker(self): + ln_xprv = self.db.get('lightning_xprv') or self.db.get('lightning_privkey2') + assert ln_xprv + self.lnworker = MockLNWallet(self, ln_xprv) + def basename(self): + passphrase = self.db.get("keystore").get("passphrase") + assert passphrase + return passphrase # lol, super secure name + +def create_mock_lnwallet(*, name, has_anchors) -> 'MockLNWallet': + _user_dir = tempfile.mkdtemp(prefix="electrum-lnpeer-test-") + config = SimpleConfig({}, read_user_dir_function=lambda: _user_dir) + config.ENABLE_ANCHOR_CHANNELS = has_anchors + config.INITIAL_TRAMPOLINE_FEE_LEVEL = 0 + + network = MockNetwork(config=config) + + wallet = restore_wallet_from_text__for_unittest( + "9dk", path=None, passphrase=name, config=config, + wallet_factory=MockWalletFactory, + )['wallet'] # type: MockStandardWallet + wallet.is_up_to_date = lambda: True + wallet.adb.network = wallet.network = network + + lnworker = wallet.lnworker + assert isinstance(lnworker, MockLNWallet), f"{lnworker=!r}" + lnworker._user_dir = _user_dir + lnworker.lnpeermgr.network = network + lnworker.logger.info(f"created LNWallet[{name}] with nodeID={lnworker.node_keypair.pubkey.hex()}") + return lnworker class MockLNWallet(LNWallet): MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1 TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 0 MPP_SPLIT_PART_FRACTION = 1 # this disables the forced splitting - def __init__(self, *, name, has_anchors): - self.name = name - - self._user_dir = tempfile.mkdtemp(prefix="electrum-lnpeer-test-") - self.config = SimpleConfig({}, read_user_dir_function=lambda: self._user_dir) - self.config.ENABLE_ANCHOR_CHANNELS = has_anchors - self.config.INITIAL_TRAMPOLINE_FEE_LEVEL = 0 - - network = MockNetwork(config=self.config) - - wallet = restore_wallet_from_text__for_unittest( - "9dk", path=None, passphrase=name, config=self.config)['wallet'] # type: Abstract_Wallet - wallet.is_up_to_date = lambda: True - wallet.adb.network = wallet.network = network - #assert wallet.lnworker is None # FIXME xxxxx wallet already has another lnworker by now >.< - - features = LnFeatures(0) - features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT - features |= LnFeatures.OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT - features |= LnFeatures.VAR_ONION_OPT - features |= LnFeatures.PAYMENT_SECRET_OPT - features |= LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM - features |= LnFeatures.OPTION_CHANNEL_TYPE_OPT - features |= LnFeatures.OPTION_SCID_ALIAS_OPT - features |= LnFeatures.OPTION_STATIC_REMOTEKEY_OPT - - ln_xprv = _bip32_from_name(name).to_xprv() - LNWallet.__init__(self, wallet=wallet, xprv=ln_xprv, features=features) - - self.lnpeermgr = MockLNPeerManager( - node_keypair=self.node_keypair, - config=self.config, - features=features, - lnwallet=self, - network=network, - ) - - self.logger.info(f"created LNWallet[{name}] with nodeID={self.node_keypair.pubkey.hex()}") + def __init__(self, *args, **kwargs): + LNWallet.__init__(self, *args, **kwargs) + self.features &= ~LnFeatures.BASIC_MPP_OPT # by default, disable MPP def _add_channel(self, chan: Channel): self._channels[chan.channel_id] = chan + # assert chan.lnworker == self # this fails as some tests are reusing chans in a weird way chan.lnworker = self @LNWallet.features.setter def features(self, value): self.lnpeermgr.features = value - def save_channel(self, chan): - pass - #print("Ignoring channel save") - - def diagnostic_name(self): - return self.name + @property + def name(self): + return self.wallet.basename() async def stop(self): await LNWallet.stop(self) @@ -524,7 +501,7 @@ class TestPeer(ElectrumTestCase): def prepare_lnwallets(self, graph_definition) -> Mapping[str, MockLNWallet]: workers = {} # type: Dict[str, MockLNWallet] for a, definition in graph_definition.items(): - workers[a] = MockLNWallet(name=a, has_anchors=self.TEST_ANCHOR_CHANNELS) + workers[a] = create_mock_lnwallet(name=a, has_anchors=self.TEST_ANCHOR_CHANNELS) self._lnworkers_created.extend(list(workers.values())) return workers diff --git a/tests/test_onion_message.py b/tests/test_onion_message.py index f33e2368f..058e6b802 100644 --- a/tests/test_onion_message.py +++ b/tests/test_onion_message.py @@ -25,7 +25,8 @@ from electrum.util import bfh, read_json_file, OldTaskGroup, get_asyncio_loop from electrum.logging import console_stderr_handler from . import ElectrumTestCase -from .test_lnpeer import keypair, MockLNWallet +from .test_lnpeer import create_mock_lnwallet + TIME_STEP = 0.01 # run tests 100 x faster OnionMessageManager.SLEEP_DELAY *= TIME_STEP @@ -352,7 +353,7 @@ class TestOnionMessageManager(ElectrumTestCase): async def test_request_and_reply(self): n = MockNetwork() - lnw = MockLNWallet(name='test_request_and_reply', has_anchors=False) + lnw = create_mock_lnwallet(name='test_request_and_reply', has_anchors=False) def slow(*args, **kwargs): time.sleep(2*TIME_STEP) @@ -398,7 +399,7 @@ class TestOnionMessageManager(ElectrumTestCase): async def test_forward(self): n = MockNetwork() - lnw = MockLNWallet(name='alice', has_anchors=False) + lnw = create_mock_lnwallet(name='alice', has_anchors=False) lnw.node_keypair = self.alice self.was_sent = False @@ -435,7 +436,7 @@ class TestOnionMessageManager(ElectrumTestCase): async def test_receive_unsolicited(self): n = MockNetwork() - lnw = MockLNWallet(name='dave', has_anchors=False) + lnw = create_mock_lnwallet(name='dave', has_anchors=False) lnw.node_keypair = self.dave t = OnionMessageManager(lnw)