1
0

Merge pull request #10043 from SomberNight/202507_keystore_cleanup

keystore.py: some clean-up
This commit is contained in:
ghost43
2025-07-31 15:34:10 +00:00
committed by GitHub
3 changed files with 73 additions and 75 deletions

View File

@@ -290,12 +290,12 @@ class BIP32Node(NamedTuple):
return hash_160(self.eckey.get_public_key_bytes(compressed=True))[0:4]
def xpub_type(x: str):
def xpub_type(x: str) -> str:
assert x is not None
return BIP32Node.from_xkey(x).xtype
def is_xpub(text):
def is_xpub(text: str) -> bool:
try:
node = BIP32Node.from_xkey(text)
return not node.is_private()
@@ -303,7 +303,7 @@ def is_xpub(text):
return False
def is_xprv(text):
def is_xprv(text: str) -> bool:
try:
node = BIP32Node.from_xkey(text)
return node.is_private()
@@ -311,7 +311,7 @@ def is_xprv(text):
return False
def xpub_from_xprv(xprv):
def xpub_from_xprv(xprv: str) -> str:
return BIP32Node.from_xkey(xprv).to_xpub()

View File

@@ -28,7 +28,7 @@ from unicodedata import normalize
import hashlib
import re
import copy
from typing import Tuple, TYPE_CHECKING, Union, Sequence, Optional, Dict, List, NamedTuple
from typing import Tuple, TYPE_CHECKING, Union, Sequence, Optional, Dict, List, NamedTuple, Any, Type
from functools import lru_cache, wraps
from abc import ABC, abstractmethod
@@ -105,7 +105,7 @@ class KeyStore(Logger, ABC):
return f'{self.type}'
@abstractmethod
def may_have_password(self):
def may_have_password(self) -> bool:
"""Returns whether the keystore can be encrypted with a password."""
pass
@@ -129,7 +129,7 @@ class KeyStore(Logger, ABC):
keypairs[pubkey] = derivation
return keypairs
def can_sign(self, tx: 'Transaction', *, ignore_watching_only=False) -> bool:
def can_sign(self, tx: 'Transaction', *, ignore_watching_only: bool = False) -> bool:
"""Returns whether this keystore could sign *something* in this tx."""
if not ignore_watching_only and self.is_watching_only():
return False
@@ -137,7 +137,7 @@ class KeyStore(Logger, ABC):
return False
return bool(self._get_tx_derivations(tx))
def can_sign_txin(self, txin: 'TxInput', *, ignore_watching_only=False) -> bool:
def can_sign_txin(self, txin: 'TxInput', *, ignore_watching_only: bool = False) -> bool:
"""Returns whether this keystore could sign this txin."""
if not ignore_watching_only and self.is_watching_only():
return False
@@ -149,7 +149,7 @@ class KeyStore(Logger, ABC):
return not self.is_watching_only()
@abstractmethod
def dump(self) -> dict:
def dump(self) -> dict[str, Any]:
pass
@abstractmethod
@@ -213,7 +213,7 @@ class KeyStore(Logger, ABC):
class Software_KeyStore(KeyStore):
def __init__(self, d):
def __init__(self, d: dict):
KeyStore.__init__(self)
self.pw_hash_version = d.get('pw_hash_version', 1)
if self.pw_hash_version not in SUPPORTED_PW_HASH_VERSIONS:
@@ -249,7 +249,7 @@ class Software_KeyStore(KeyStore):
tx.sign(keypairs)
@abstractmethod
def update_password(self, old_password, new_password):
def update_password(self, old_password, new_password) -> None:
pass
@abstractmethod
@@ -268,7 +268,7 @@ class Imported_KeyStore(Software_KeyStore):
type = 'imported'
def __init__(self, d):
def __init__(self, d: dict):
Software_KeyStore.__init__(self, d)
self.keypairs = d.get('keypairs', {}) # type: Dict[str, str]
@@ -290,7 +290,7 @@ class Imported_KeyStore(Software_KeyStore):
pubkey = list(self.keypairs.keys())[0]
self.get_private_key(pubkey, password)
def import_privkey(self, sec, password):
def import_privkey(self, sec: str, password) -> Tuple[str, str]:
txin_type, privkey, compressed = deserialize_privkey(sec)
pubkey = ecc.ECPrivkey(privkey).get_public_key_hex(compressed=compressed)
# re-serialize the key so the internal storage format is consistent
@@ -303,7 +303,7 @@ class Imported_KeyStore(Software_KeyStore):
self.keypairs[pubkey] = pw_encode(serialized_privkey, password, version=self.pw_hash_version)
return txin_type, pubkey
def delete_imported_key(self, key):
def delete_imported_key(self, key: str) -> None:
self.keypairs.pop(key)
def get_private_key(self, pubkey: str, password):
@@ -343,7 +343,7 @@ class Imported_KeyStore(Software_KeyStore):
class Deterministic_KeyStore(Software_KeyStore):
def __init__(self, d):
def __init__(self, d: dict):
Software_KeyStore.__init__(self, d)
self.seed = d.get('seed', '') # only electrum seeds
self.passphrase = d.get('passphrase', '')
@@ -384,12 +384,12 @@ class Deterministic_KeyStore(Software_KeyStore):
self.seed = self.format_seed(seed)
self._seed_type = calc_seed_type(seed) or None
def get_seed(self, password):
def get_seed(self, password) -> str:
if not self.has_seed():
raise Exception("This wallet has no seed words")
return pw_decode(self.seed, password, version=self.pw_hash_version)
def get_passphrase(self, password):
def get_passphrase(self, password) -> str:
if self.passphrase:
return pw_decode(self.passphrase, password, version=self.pw_hash_version)
else:
@@ -633,7 +633,7 @@ class BIP32_KeyStore(Xpub, Deterministic_KeyStore):
type = 'bip32'
def __init__(self, d):
def __init__(self, d: dict):
Xpub.__init__(self, derivation_prefix=d.get('derivation'), root_fingerprint=d.get('root_fingerprint'))
Deterministic_KeyStore.__init__(self, d)
self.xpub = d.get('xpub')
@@ -653,7 +653,7 @@ class BIP32_KeyStore(Xpub, Deterministic_KeyStore):
d['root_fingerprint'] = self.get_root_fingerprint()
return d
def get_master_private_key(self, password):
def get_master_private_key(self, password) -> str:
return pw_decode(self.xprv, password, version=self.pw_hash_version)
@also_test_none_password
@@ -707,11 +707,6 @@ class BIP32_KeyStore(Xpub, Deterministic_KeyStore):
pk = node.eckey.get_secret_bytes()
return pk, True
def get_keypair(self, sequence, password):
k, _ = self.get_private_key(sequence, password)
cK = ecc.ECPrivkey(k).get_public_key_bytes()
return cK, k
def can_have_deterministic_lightning_xprv(self):
if (self.get_seed_type() == 'segwit'
and self.get_bip32_node_for_xpub().xtype == 'p2wpkh'):
@@ -729,16 +724,18 @@ class Old_KeyStore(MasterPublicKeyMixin, Deterministic_KeyStore):
type = 'old'
def __init__(self, d):
def __init__(self, d: dict):
Deterministic_KeyStore.__init__(self, d)
self.mpk = d.get('mpk')
self.mpk = d.get('mpk') # type: Optional[str]
self._root_fingerprint = None
def watching_only_keystore(self):
return Old_KeyStore({'mpk': self.mpk})
def get_hex_seed(self, password):
return pw_decode(self.seed, password, version=self.pw_hash_version).encode('utf8')
def _get_hex_seed(self, password) -> str:
hex_str = pw_decode(self.seed, password, version=self.pw_hash_version)
assert is_hex_str(hex_str), f"expected hex str, got {type(hex_str)} with {len(hex_str)=}"
return hex_str
def dump(self):
d = Deterministic_KeyStore.dump(self)
@@ -747,10 +744,10 @@ class Old_KeyStore(MasterPublicKeyMixin, Deterministic_KeyStore):
def add_seed(self, seed):
Deterministic_KeyStore.add_seed(self, seed)
s = self.get_hex_seed(None)
self.mpk = self.mpk_from_seed(s)
hex_seed = self._get_hex_seed(None)
self.mpk = self.mpk_from_seed(hex_seed)
def add_master_public_key(self, mpk) -> None:
def add_master_public_key(self, mpk: str) -> None:
self.mpk = mpk
def format_seed(self, seed):
@@ -771,28 +768,30 @@ class Old_KeyStore(MasterPublicKeyMixin, Deterministic_KeyStore):
def get_seed(self, password):
from . import old_mnemonic
s = self.get_hex_seed(password)
return ' '.join(old_mnemonic.mn_encode(s))
hex_seed = self._get_hex_seed(password)
return ' '.join(old_mnemonic.mn_encode(hex_seed))
@classmethod
def mpk_from_seed(klass, seed):
secexp = klass.stretch_key(seed)
def mpk_from_seed(cls, hex_seed: str) -> str:
secexp = cls.stretch_key(hex_seed)
privkey = ecc.ECPrivkey.from_secret_scalar(secexp)
return privkey.get_public_key_hex(compressed=False)[2:]
@classmethod
def stretch_key(self, seed):
x = seed
def stretch_key(cls, hex_seed: str) -> int:
assert is_hex_str(hex_seed), f"expected hex str, got {type(hex_seed)} with {len(hex_seed)=}"
encoded_hex_seed = hex_seed.encode('ascii')
x = encoded_hex_seed
for i in range(100000):
x = hashlib.sha256(x + seed).digest()
x = hashlib.sha256(x + encoded_hex_seed).digest()
return string_to_number(x)
@classmethod
def get_sequence(self, mpk, for_change, n):
def get_sequence(cls, mpk: str, for_change: int, n: int) -> int:
return string_to_number(sha256d(("%d:%d:"%(n, for_change)).encode('ascii') + bfh(mpk)))
@classmethod
def get_pubkey_from_mpk(cls, mpk, for_change, n) -> bytes:
def get_pubkey_from_mpk(cls, mpk: str, for_change: int, n: int) -> bytes:
z = cls.get_sequence(mpk, for_change, n)
master_public_key = ecc.ECPubkey(bfh('04'+mpk))
public_key = master_public_key + z*ecc.GENERATOR
@@ -805,22 +804,24 @@ class Old_KeyStore(MasterPublicKeyMixin, Deterministic_KeyStore):
raise CannotDerivePubkey("forbidden path")
return self.get_pubkey_from_mpk(self.mpk, for_change, n)
def _get_private_key_from_stretched_exponent(self, for_change, n, secexp):
def _get_private_key_from_stretched_exponent(self, for_change: int, n: int, secexp: int) -> bytes:
secexp = (secexp + self.get_sequence(self.mpk, for_change, n)) % ecc.CURVE_ORDER
pk = int.to_bytes(secexp, length=32, byteorder='big', signed=False)
return pk
def get_private_key(self, sequence: Sequence[int], password):
seed = self.get_hex_seed(password)
secexp = self.stretch_key(seed)
self._check_seed(seed, secexp=secexp)
hex_seed = self._get_hex_seed(password)
secexp = self.stretch_key(hex_seed)
self._check_seed(hex_seed, secexp=secexp)
for_change, n = sequence
assert isinstance(for_change, int), type(for_change)
assert isinstance(n, int), type(n)
pk = self._get_private_key_from_stretched_exponent(for_change, n, secexp)
return pk, False
def _check_seed(self, seed, *, secexp=None):
def _check_seed(self, hex_seed: str, *, secexp: int = None) -> None:
if secexp is None:
secexp = self.stretch_key(seed)
secexp = self.stretch_key(hex_seed)
master_private_key = ecc.ECPrivkey.from_secret_scalar(secexp)
master_public_key = master_private_key.get_public_key_bytes(compressed=False)[1:]
if master_public_key != bfh(self.mpk):
@@ -828,8 +829,8 @@ class Old_KeyStore(MasterPublicKeyMixin, Deterministic_KeyStore):
@also_test_none_password
def check_password(self, password):
seed = self.get_hex_seed(password)
self._check_seed(seed)
hex_seed = self._get_hex_seed(password)
self._check_seed(hex_seed)
def get_master_public_key(self):
return self.mpk
@@ -894,7 +895,7 @@ class Hardware_KeyStore(Xpub, KeyStore):
self.handler = None # type: Optional[HardwareHandlerBase]
run_hook('init_keystore', self)
def set_label(self, label):
def set_label(self, label: Optional[str]) -> None:
self.label = label
def may_have_password(self):
@@ -917,16 +918,6 @@ class Hardware_KeyStore(Xpub, KeyStore):
'soft_device_id': self.soft_device_id,
}
def unpaired(self):
'''A device paired with the wallet was disconnected. This can be
called in any thread context.'''
self.logger.info("unpaired")
def paired(self):
'''A device paired with the wallet was (re-)connected. This can be
called in any thread context.'''
self.logger.info("paired")
def is_watching_only(self):
'''The wallet is not watching-only; the user will be prompted for
pin and passphrase as appropriate when needed.'''
@@ -1079,9 +1070,9 @@ def xtype_from_derivation(derivation: str) -> str:
return 'standard'
hw_keystores = {}
hw_keystores = {} # type: Dict[str, Type[Hardware_KeyStore]]
def register_keystore(hw_type, constructor):
def register_keystore(hw_type: str, constructor: Type[Hardware_KeyStore]) -> None:
hw_keystores[hw_type] = constructor
def hardware_keystore(d) -> Hardware_KeyStore:
@@ -1125,12 +1116,12 @@ def is_old_mpk(mpk: str) -> bool:
return True
def is_address_list(text):
def is_address_list(text: str) -> bool:
parts = text.split()
return bool(parts) and all(bitcoin.is_address(x) for x in parts)
def get_private_keys(text, *, allow_spaces_inside_key=True, raise_on_error=False):
def get_private_keys(text: str, *, allow_spaces_inside_key=True, raise_on_error=False) -> Optional[Sequence[str]]:
if allow_spaces_inside_key: # see #1612
parts = text.split('\n')
parts = map(lambda x: ''.join(x.split()), parts)
@@ -1139,23 +1130,24 @@ def get_private_keys(text, *, allow_spaces_inside_key=True, raise_on_error=False
parts = text.split()
if bool(parts) and all(bitcoin.is_private_key(x, raise_on_error=raise_on_error) for x in parts):
return parts
return None
def is_private_key_list(text, *, allow_spaces_inside_key=True, raise_on_error=False):
def is_private_key_list(text: str, *, allow_spaces_inside_key: bool = True, raise_on_error: bool = False) -> bool:
return bool(get_private_keys(text,
allow_spaces_inside_key=allow_spaces_inside_key,
raise_on_error=raise_on_error))
def is_master_key(x):
def is_master_key(x: str) -> bool:
return is_old_mpk(x) or is_bip32_key(x)
def is_bip32_key(x):
def is_bip32_key(x: str) -> bool:
return is_xprv(x) or is_xpub(x)
def bip44_derivation(account_id, bip43_purpose=44):
def bip44_derivation(account_id: int, bip43_purpose: int = 44) -> str:
coin = constants.net.BIP44_COIN_TYPE
der = "m/%d'/%d'/%d'" % (bip43_purpose, coin, int(account_id))
return normalize_bip32_derivation(der)
@@ -1173,7 +1165,7 @@ def purpose48_derivation(account_id: int, xtype: str) -> str:
return normalize_bip32_derivation(der)
def from_seed(seed: str, *, passphrase: Optional[str], for_multisig: bool = False):
def from_seed(seed: str, *, passphrase: Optional[str], for_multisig: bool = False) -> Union[BIP32_KeyStore, Old_KeyStore]:
passphrase = passphrase or ""
t = calc_seed_type(seed)
if t == 'old':
@@ -1197,28 +1189,28 @@ def from_seed(seed: str, *, passphrase: Optional[str], for_multisig: bool = Fals
raise BitcoinException('Unexpected seed type {}'.format(repr(t)))
return keystore
def from_private_key_list(text):
def from_private_key_list(text: str) -> Imported_KeyStore:
keystore = Imported_KeyStore({})
for x in get_private_keys(text):
keystore.import_privkey(x, None)
return keystore
def from_old_mpk(mpk):
def from_old_mpk(mpk: str) -> Old_KeyStore:
keystore = Old_KeyStore({})
keystore.add_master_public_key(mpk)
return keystore
def from_xpub(xpub):
def from_xpub(xpub: str) -> BIP32_KeyStore:
k = BIP32_KeyStore({})
k.add_xpub(xpub)
return k
def from_xprv(xprv):
def from_xprv(xprv: str) -> BIP32_KeyStore:
k = BIP32_KeyStore({})
k.add_xprv(xprv)
return k
def from_master_key(text):
def from_master_key(text: str) -> Union[BIP32_KeyStore, Old_KeyStore]:
if is_xprv(text):
k = from_xprv(text)
elif is_old_mpk(text):

View File

@@ -23,7 +23,10 @@
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import Sequence, Union
from .mnemonic import Wordlist
from .util import is_hex_str
# list of words from http://en.wiktionary.org/wiki/Wiktionary:Frequency_lists/Contemporary_poetry
@@ -1666,7 +1669,10 @@ assert n == 1626
# Note about US patent no 5892470: Here each word does not represent a given digit.
# Instead, the digit represented by a word is variable, it depends on the previous word.
def mn_encode(message):
def mn_encode(message: str) -> Sequence[str]:
# note: to generate an 'old'-type mnemonic for testing:
# " ".join(electrum.old_mnemonic.mn_encode(secrets.token_hex(16)))
assert is_hex_str(message), f"expected hex, got {type(message)}"
assert len(message) % 8 == 0
out = []
for i in range(len(message)//8):
@@ -1679,7 +1685,7 @@ def mn_encode(message):
return out
def mn_decode(wlist):
def mn_decode(wlist: Sequence[str]) -> str:
out = ''
for i in range(len(wlist)//3):
word1, word2, word3 = wlist[3*i:3*i+3]