1
0

keystore: add more type hints

wtf is going on with the type of hex_seed

and mn_encode is polymorphic in a really ugly way...
This commit is contained in:
SomberNight
2025-07-18 00:13:00 +00:00
parent f6db5fd77c
commit a257072391
3 changed files with 61 additions and 48 deletions

View File

@@ -290,12 +290,12 @@ class BIP32Node(NamedTuple):
return hash_160(self.eckey.get_public_key_bytes(compressed=True))[0:4]
def xpub_type(x: str):
def xpub_type(x: str) -> str:
assert x is not None
return BIP32Node.from_xkey(x).xtype
def is_xpub(text):
def is_xpub(text: str) -> bool:
try:
node = BIP32Node.from_xkey(text)
return not node.is_private()
@@ -303,7 +303,7 @@ def is_xpub(text):
return False
def is_xprv(text):
def is_xprv(text: str) -> bool:
try:
node = BIP32Node.from_xkey(text)
return node.is_private()
@@ -311,7 +311,7 @@ def is_xprv(text):
return False
def xpub_from_xprv(xprv):
def xpub_from_xprv(xprv: str) -> str:
return BIP32Node.from_xkey(xprv).to_xpub()

View File

@@ -28,7 +28,7 @@ from unicodedata import normalize
import hashlib
import re
import copy
from typing import Tuple, TYPE_CHECKING, Union, Sequence, Optional, Dict, List, NamedTuple
from typing import Tuple, TYPE_CHECKING, Union, Sequence, Optional, Dict, List, NamedTuple, Any, Type
from functools import lru_cache, wraps
from abc import ABC, abstractmethod
@@ -105,7 +105,7 @@ class KeyStore(Logger, ABC):
return f'{self.type}'
@abstractmethod
def may_have_password(self):
def may_have_password(self) -> bool:
"""Returns whether the keystore can be encrypted with a password."""
pass
@@ -129,7 +129,7 @@ class KeyStore(Logger, ABC):
keypairs[pubkey] = derivation
return keypairs
def can_sign(self, tx: 'Transaction', *, ignore_watching_only=False) -> bool:
def can_sign(self, tx: 'Transaction', *, ignore_watching_only: bool = False) -> bool:
"""Returns whether this keystore could sign *something* in this tx."""
if not ignore_watching_only and self.is_watching_only():
return False
@@ -137,7 +137,7 @@ class KeyStore(Logger, ABC):
return False
return bool(self._get_tx_derivations(tx))
def can_sign_txin(self, txin: 'TxInput', *, ignore_watching_only=False) -> bool:
def can_sign_txin(self, txin: 'TxInput', *, ignore_watching_only: bool = False) -> bool:
"""Returns whether this keystore could sign this txin."""
if not ignore_watching_only and self.is_watching_only():
return False
@@ -149,7 +149,7 @@ class KeyStore(Logger, ABC):
return not self.is_watching_only()
@abstractmethod
def dump(self) -> dict:
def dump(self) -> dict[str, Any]:
pass
@abstractmethod
@@ -213,7 +213,7 @@ class KeyStore(Logger, ABC):
class Software_KeyStore(KeyStore):
def __init__(self, d):
def __init__(self, d: dict):
KeyStore.__init__(self)
self.pw_hash_version = d.get('pw_hash_version', 1)
if self.pw_hash_version not in SUPPORTED_PW_HASH_VERSIONS:
@@ -249,7 +249,7 @@ class Software_KeyStore(KeyStore):
tx.sign(keypairs)
@abstractmethod
def update_password(self, old_password, new_password):
def update_password(self, old_password, new_password) -> None:
pass
@abstractmethod
@@ -268,7 +268,7 @@ class Imported_KeyStore(Software_KeyStore):
type = 'imported'
def __init__(self, d):
def __init__(self, d: dict):
Software_KeyStore.__init__(self, d)
self.keypairs = d.get('keypairs', {}) # type: Dict[str, str]
@@ -290,7 +290,7 @@ class Imported_KeyStore(Software_KeyStore):
pubkey = list(self.keypairs.keys())[0]
self.get_private_key(pubkey, password)
def import_privkey(self, sec, password):
def import_privkey(self, sec: str, password) -> Tuple[str, str]:
txin_type, privkey, compressed = deserialize_privkey(sec)
pubkey = ecc.ECPrivkey(privkey).get_public_key_hex(compressed=compressed)
# re-serialize the key so the internal storage format is consistent
@@ -303,7 +303,7 @@ class Imported_KeyStore(Software_KeyStore):
self.keypairs[pubkey] = pw_encode(serialized_privkey, password, version=self.pw_hash_version)
return txin_type, pubkey
def delete_imported_key(self, key):
def delete_imported_key(self, key: str) -> None:
self.keypairs.pop(key)
def get_private_key(self, pubkey: str, password):
@@ -343,7 +343,7 @@ class Imported_KeyStore(Software_KeyStore):
class Deterministic_KeyStore(Software_KeyStore):
def __init__(self, d):
def __init__(self, d: dict):
Software_KeyStore.__init__(self, d)
self.seed = d.get('seed', '') # only electrum seeds
self.passphrase = d.get('passphrase', '')
@@ -384,12 +384,12 @@ class Deterministic_KeyStore(Software_KeyStore):
self.seed = self.format_seed(seed)
self._seed_type = calc_seed_type(seed) or None
def get_seed(self, password):
def get_seed(self, password) -> str:
if not self.has_seed():
raise Exception("This wallet has no seed words")
return pw_decode(self.seed, password, version=self.pw_hash_version)
def get_passphrase(self, password):
def get_passphrase(self, password) -> str:
if self.passphrase:
return pw_decode(self.passphrase, password, version=self.pw_hash_version)
else:
@@ -633,7 +633,7 @@ class BIP32_KeyStore(Xpub, Deterministic_KeyStore):
type = 'bip32'
def __init__(self, d):
def __init__(self, d: dict):
Xpub.__init__(self, derivation_prefix=d.get('derivation'), root_fingerprint=d.get('root_fingerprint'))
Deterministic_KeyStore.__init__(self, d)
self.xpub = d.get('xpub')
@@ -653,7 +653,7 @@ class BIP32_KeyStore(Xpub, Deterministic_KeyStore):
d['root_fingerprint'] = self.get_root_fingerprint()
return d
def get_master_private_key(self, password):
def get_master_private_key(self, password) -> str:
return pw_decode(self.xprv, password, version=self.pw_hash_version)
@also_test_none_password
@@ -724,16 +724,18 @@ class Old_KeyStore(MasterPublicKeyMixin, Deterministic_KeyStore):
type = 'old'
def __init__(self, d):
def __init__(self, d: dict):
Deterministic_KeyStore.__init__(self, d)
self.mpk = d.get('mpk')
self.mpk = d.get('mpk') # type: Optional[str]
self._root_fingerprint = None
def watching_only_keystore(self):
return Old_KeyStore({'mpk': self.mpk})
def get_hex_seed(self, password):
return pw_decode(self.seed, password, version=self.pw_hash_version).encode('utf8')
def get_hex_seed(self, password) -> bytes:
# FIXME we return bytes that only contain hex characters.
hex_str = pw_decode(self.seed, password, version=self.pw_hash_version)
return hex_str.encode('utf8')
def dump(self):
d = Deterministic_KeyStore.dump(self)
@@ -745,7 +747,7 @@ class Old_KeyStore(MasterPublicKeyMixin, Deterministic_KeyStore):
s = self.get_hex_seed(None)
self.mpk = self.mpk_from_seed(s)
def add_master_public_key(self, mpk) -> None:
def add_master_public_key(self, mpk: str) -> None:
self.mpk = mpk
def format_seed(self, seed):
@@ -770,24 +772,27 @@ class Old_KeyStore(MasterPublicKeyMixin, Deterministic_KeyStore):
return ' '.join(old_mnemonic.mn_encode(s))
@classmethod
def mpk_from_seed(klass, seed):
secexp = klass.stretch_key(seed)
def mpk_from_seed(cls, seed: bytes) -> str:
# FIXME `seed` is bytes that only contain hex characters.
secexp = cls.stretch_key(seed)
privkey = ecc.ECPrivkey.from_secret_scalar(secexp)
return privkey.get_public_key_hex(compressed=False)[2:]
@classmethod
def stretch_key(self, seed):
def stretch_key(cls, seed: bytes) -> int:
# FIXME `seed` is bytes that only contain hex characters.
assert isinstance(seed, bytes), f"expected bytes, got {type(seed)}"
x = seed
for i in range(100000):
x = hashlib.sha256(x + seed).digest()
return string_to_number(x)
@classmethod
def get_sequence(self, mpk, for_change, n):
def get_sequence(cls, mpk: str, for_change: int, n: int) -> int:
return string_to_number(sha256d(("%d:%d:"%(n, for_change)).encode('ascii') + bfh(mpk)))
@classmethod
def get_pubkey_from_mpk(cls, mpk, for_change, n) -> bytes:
def get_pubkey_from_mpk(cls, mpk: str, for_change: int, n: int) -> bytes:
z = cls.get_sequence(mpk, for_change, n)
master_public_key = ecc.ECPubkey(bfh('04'+mpk))
public_key = master_public_key + z*ecc.GENERATOR
@@ -800,7 +805,7 @@ class Old_KeyStore(MasterPublicKeyMixin, Deterministic_KeyStore):
raise CannotDerivePubkey("forbidden path")
return self.get_pubkey_from_mpk(self.mpk, for_change, n)
def _get_private_key_from_stretched_exponent(self, for_change, n, secexp):
def _get_private_key_from_stretched_exponent(self, for_change: int, n: int, secexp: int) -> bytes:
secexp = (secexp + self.get_sequence(self.mpk, for_change, n)) % ecc.CURVE_ORDER
pk = int.to_bytes(secexp, length=32, byteorder='big', signed=False)
return pk
@@ -813,7 +818,8 @@ class Old_KeyStore(MasterPublicKeyMixin, Deterministic_KeyStore):
pk = self._get_private_key_from_stretched_exponent(for_change, n, secexp)
return pk, False
def _check_seed(self, seed, *, secexp=None):
def _check_seed(self, seed: bytes, *, secexp: int = None) -> None:
# FIXME `seed` is bytes that only contain hex characters.
if secexp is None:
secexp = self.stretch_key(seed)
master_private_key = ecc.ECPrivkey.from_secret_scalar(secexp)
@@ -889,7 +895,7 @@ class Hardware_KeyStore(Xpub, KeyStore):
self.handler = None # type: Optional[HardwareHandlerBase]
run_hook('init_keystore', self)
def set_label(self, label):
def set_label(self, label: Optional[str]) -> None:
self.label = label
def may_have_password(self):
@@ -1064,9 +1070,9 @@ def xtype_from_derivation(derivation: str) -> str:
return 'standard'
hw_keystores = {}
hw_keystores = {} # type: Dict[str, Type[Hardware_KeyStore]]
def register_keystore(hw_type, constructor):
def register_keystore(hw_type: str, constructor: Type[Hardware_KeyStore]) -> None:
hw_keystores[hw_type] = constructor
def hardware_keystore(d) -> Hardware_KeyStore:
@@ -1110,12 +1116,12 @@ def is_old_mpk(mpk: str) -> bool:
return True
def is_address_list(text):
def is_address_list(text: str) -> bool:
parts = text.split()
return bool(parts) and all(bitcoin.is_address(x) for x in parts)
def get_private_keys(text, *, allow_spaces_inside_key=True, raise_on_error=False):
def get_private_keys(text: str, *, allow_spaces_inside_key=True, raise_on_error=False) -> Optional[Sequence[str]]:
if allow_spaces_inside_key: # see #1612
parts = text.split('\n')
parts = map(lambda x: ''.join(x.split()), parts)
@@ -1124,23 +1130,24 @@ def get_private_keys(text, *, allow_spaces_inside_key=True, raise_on_error=False
parts = text.split()
if bool(parts) and all(bitcoin.is_private_key(x, raise_on_error=raise_on_error) for x in parts):
return parts
return None
def is_private_key_list(text, *, allow_spaces_inside_key=True, raise_on_error=False):
def is_private_key_list(text: str, *, allow_spaces_inside_key: bool = True, raise_on_error: bool = False) -> bool:
return bool(get_private_keys(text,
allow_spaces_inside_key=allow_spaces_inside_key,
raise_on_error=raise_on_error))
def is_master_key(x):
def is_master_key(x: str) -> bool:
return is_old_mpk(x) or is_bip32_key(x)
def is_bip32_key(x):
def is_bip32_key(x: str) -> bool:
return is_xprv(x) or is_xpub(x)
def bip44_derivation(account_id, bip43_purpose=44):
def bip44_derivation(account_id: int, bip43_purpose: int = 44) -> str:
coin = constants.net.BIP44_COIN_TYPE
der = "m/%d'/%d'/%d'" % (bip43_purpose, coin, int(account_id))
return normalize_bip32_derivation(der)
@@ -1158,7 +1165,7 @@ def purpose48_derivation(account_id: int, xtype: str) -> str:
return normalize_bip32_derivation(der)
def from_seed(seed: str, *, passphrase: Optional[str], for_multisig: bool = False):
def from_seed(seed: str, *, passphrase: Optional[str], for_multisig: bool = False) -> Union[BIP32_KeyStore, Old_KeyStore]:
passphrase = passphrase or ""
t = calc_seed_type(seed)
if t == 'old':
@@ -1182,28 +1189,28 @@ def from_seed(seed: str, *, passphrase: Optional[str], for_multisig: bool = Fals
raise BitcoinException('Unexpected seed type {}'.format(repr(t)))
return keystore
def from_private_key_list(text):
def from_private_key_list(text: str) -> Imported_KeyStore:
keystore = Imported_KeyStore({})
for x in get_private_keys(text):
keystore.import_privkey(x, None)
return keystore
def from_old_mpk(mpk):
def from_old_mpk(mpk: str) -> Old_KeyStore:
keystore = Old_KeyStore({})
keystore.add_master_public_key(mpk)
return keystore
def from_xpub(xpub):
def from_xpub(xpub: str) -> BIP32_KeyStore:
k = BIP32_KeyStore({})
k.add_xpub(xpub)
return k
def from_xprv(xprv):
def from_xprv(xprv: str) -> BIP32_KeyStore:
k = BIP32_KeyStore({})
k.add_xprv(xprv)
return k
def from_master_key(text):
def from_master_key(text: str) -> Union[BIP32_KeyStore, Old_KeyStore]:
if is_xprv(text):
k = from_xprv(text)
elif is_old_mpk(text):

View File

@@ -23,6 +23,8 @@
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import Sequence, Union
from .mnemonic import Wordlist
@@ -1666,7 +1668,11 @@ assert n == 1626
# Note about US patent no 5892470: Here each word does not represent a given digit.
# Instead, the digit represented by a word is variable, it depends on the previous word.
def mn_encode(message):
def mn_encode(message: Union[str, bytes]) -> Sequence[str]:
# FIXME `message` is either bytes that can only contain hex chars, or is a hex str
# note: to generate an 'old'-type mnemonic for testing:
# " ".join(electrum.old_mnemonic.mn_encode(secrets.token_hex(16)))
#assert is_hex_str(message), f"expected hex, got {type(message)}"
assert len(message) % 8 == 0
out = []
for i in range(len(message)//8):
@@ -1679,7 +1685,7 @@ def mn_encode(message):
return out
def mn_decode(wlist):
def mn_decode(wlist: Sequence[str]) -> str:
out = ''
for i in range(len(wlist)//3):
word1, word2, word3 = wlist[3*i:3*i+3]