diff --git a/electrum/daemon.py b/electrum/daemon.py index 1d50ceb04..7b21b12c3 100644 --- a/electrum/daemon.py +++ b/electrum/daemon.py @@ -476,14 +476,28 @@ class Daemon(Logger): return func_wrapper @with_wallet_lock - def load_wallet(self, path, password, *, upgrade=False) -> Optional[Abstract_Wallet]: + def load_wallet( + self, + path, + password: Optional[str], + *, + upgrade: bool = False, + force_check_password: bool = False, + ) -> Optional[Abstract_Wallet]: + """ + force_check_password: if False, the password arg is only used if it needed to decrypt the storage. + if True, the password arg is always validated. + """ assert password != '' path = standardize_path(path) wallet_key = self._wallet_key_from_path(path) # wizard will be launched if we return if wallet := self._wallets.get(wallet_key): + if force_check_password: + wallet.check_password(password) return wallet - wallet = self._load_wallet(path, password, upgrade=upgrade, config=self.config) + wallet = self._load_wallet( + path, password, upgrade=upgrade, config=self.config, force_check_password=force_check_password) if self.network: wallet.start_network(self.network) elif wallet.lnworker: @@ -501,10 +515,11 @@ class Daemon(Logger): @profiler def _load_wallet( path, - password, + password: Optional[str], *, upgrade: bool = False, config: SimpleConfig, + force_check_password: bool = False, # if set, always validate password ) -> Optional[Abstract_Wallet]: path = standardize_path(path) storage = WalletStorage(path, allow_partial_writes=config.WALLET_PARTIAL_WRITES) @@ -519,6 +534,8 @@ class Daemon(Logger): if db.get_action(): raise WalletUnfinished(db) wallet = Wallet(db, config=config) + if force_check_password: + wallet.check_password(password) return wallet @with_wallet_lock @@ -546,7 +563,7 @@ class Daemon(Logger): def stop_wallet(self, path: str) -> bool: """Returns True iff a wallet was found.""" - # note: this must not be called from the event loop. # TODO raise if so + assert util.get_running_loop() != util.get_asyncio_loop(), 'must not be called from asyncio thread' fut = asyncio.run_coroutine_threadsafe(self._stop_wallet(path), self.asyncio_loop) return fut.result() diff --git a/electrum/gui/qml/qedaemon.py b/electrum/gui/qml/qedaemon.py index a4c135d9e..651c93fd1 100644 --- a/electrum/gui/qml/qedaemon.py +++ b/electrum/gui/qml/qedaemon.py @@ -201,7 +201,7 @@ class QEDaemon(AuthMixin, QObject): wallet_already_open = self.daemon.get_wallet(self._path) if wallet_already_open is not None: - wallet_already_open_password = QEWallet.getInstanceFor(wallet_already_open).password + password = QEWallet.getInstanceFor(wallet_already_open).password def load_wallet_task(): success = False @@ -209,7 +209,13 @@ class QEDaemon(AuthMixin, QObject): local_password = password # need this in local scope wallet = None try: - wallet = self.daemon.load_wallet(self._path, local_password, upgrade=True) + wallet = self.daemon.load_wallet( + self._path, + password=local_password, + upgrade=True, + # might have a keystore password, but unencrypted storage. we want to prompt for pw even then: + force_check_password=True, + ) except InvalidPassword: self.walletRequiresPassword.emit(self._name, self._path) except FileNotFoundError: @@ -224,11 +230,6 @@ class QEDaemon(AuthMixin, QObject): if wallet is None: return - if wallet_already_open is not None: - # wallet already open. daemon.load_wallet doesn't mind, but - # we need the correct current wallet password below - local_password = wallet_already_open_password - if self.daemon.config.WALLET_USE_SINGLE_PASSWORD: self._use_single_password = self.daemon.update_password_for_directory(old_password=local_password, new_password=local_password) self._password = local_password diff --git a/tests/test_daemon.py b/tests/test_daemon.py index 5cb6c619f..08f53cf47 100644 --- a/tests/test_daemon.py +++ b/tests/test_daemon.py @@ -228,3 +228,79 @@ class TestCommandsWithDaemon(DaemonTestCase): self.assertEqual(self.SEED, await cmds.getseed(wallet_path=wpath)) self.assertEqual(self.SEED, await cmds.getseed(wallet_path=basename)) self.assertEqual(self.SEED, await cmds.getseed(wallet=wallet)) + + +class TestLoadWallet(DaemonTestCase): + + async def test_simple_load(self): + path1 = self._restore_wallet_from_text("9dk", password=None) + wallet1 = self.daemon.load_wallet(path1, password=None) + await self.daemon._stop_wallet(path1) + + async def test_password_checks_for_no_password(self): + real_password = None + path1 = self._restore_wallet_from_text("9dk", password=real_password) + # load_wallet will not validate the password arg unless needed for storage.decrypt(): + wallet1 = self.daemon.load_wallet(path1, password="garbage") + await self.daemon._stop_wallet(path1) + # unless force_check_password is set: + with self.assertRaises(util.InvalidPassword): + wallet1 = self.daemon.load_wallet(path1, password="garbage", force_check_password=True) + + wallet1 = self.daemon.load_wallet(path1, password=real_password) + await self.daemon._stop_wallet(path1) + + wallet1 = self.daemon.load_wallet(path1, password=real_password, force_check_password=True) + await self.daemon._stop_wallet(path1) + + # load_wallet will not validate the password arg if wallet is already loaded, unless force_check_password + wallet1 = self.daemon.load_wallet(path1, password=real_password) + wallet1 = self.daemon.load_wallet(path1, password="garbage") + with self.assertRaises(util.InvalidPassword): + wallet1 = self.daemon.load_wallet(path1, password="garbage", force_check_password=True) + + + async def test_password_checks_for_ks_enc(self): + real_password = "1234" + path1 = self._restore_wallet_from_text("9dk", password=real_password, encrypt_file=False) + # load_wallet will not validate the password arg unless needed for storage.decrypt(): + wallet1 = self.daemon.load_wallet(path1, password="garbage") + await self.daemon._stop_wallet(path1) + # unless force_check_password is set: + with self.assertRaises(util.InvalidPassword): + wallet1 = self.daemon.load_wallet(path1, password="garbage", force_check_password=True) + + wallet1 = self.daemon.load_wallet(path1, password=real_password) + await self.daemon._stop_wallet(path1) + + wallet1 = self.daemon.load_wallet(path1, password=real_password, force_check_password=True) + await self.daemon._stop_wallet(path1) + + # load_wallet will not validate the password arg if wallet is already loaded, unless force_check_password + wallet1 = self.daemon.load_wallet(path1, password=real_password) + wallet1 = self.daemon.load_wallet(path1, password="garbage") + with self.assertRaises(util.InvalidPassword): + wallet1 = self.daemon.load_wallet(path1, password="garbage", force_check_password=True) + + + async def test_password_checks_for_sto_enc(self): + real_password = "1234" + path1 = self._restore_wallet_from_text("9dk", password=real_password, encrypt_file=True) + + with self.assertRaises(util.InvalidPassword): + wallet1 = self.daemon.load_wallet(path1, password="garbage") + + with self.assertRaises(util.InvalidPassword): + wallet1 = self.daemon.load_wallet(path1, password="garbage", force_check_password=True) + + wallet1 = self.daemon.load_wallet(path1, password=real_password) + await self.daemon._stop_wallet(path1) + + wallet1 = self.daemon.load_wallet(path1, password=real_password, force_check_password=True) + await self.daemon._stop_wallet(path1) + + # load_wallet will not validate the password arg if wallet is already loaded, unless force_check_password + wallet1 = self.daemon.load_wallet(path1, password=real_password) + wallet1 = self.daemon.load_wallet(path1, password="garbage") + with self.assertRaises(util.InvalidPassword): + wallet1 = self.daemon.load_wallet(path1, password="garbage", force_check_password=True)