1
0

slip39: implement extendable backups

This commit is contained in:
Ondřej Vejpustek
2024-05-13 16:13:33 +02:00
parent 4b5cd0ff2b
commit 70f0ed992f
4 changed files with 300 additions and 42 deletions

View File

@@ -63,11 +63,19 @@ def _xor(a: bytes, b: bytes) -> bytes:
_ID_LENGTH_BITS = 15
"""The length of the random identifier in bits."""
_ITERATION_EXP_LENGTH_BITS = 5
_ITERATION_EXP_LENGTH_BITS = 4
"""The length of the iteration exponent in bits."""
_ID_EXP_LENGTH_WORDS = _bits_to_words(_ID_LENGTH_BITS + _ITERATION_EXP_LENGTH_BITS)
"""The length of the random identifier and iteration exponent in words."""
_EXTENDABLE_BACKUP_FLAG_LENGTH_BITS = 1
"""The length of the extendable backup flag in bits."""
_ID_EXP_LENGTH_WORDS = _bits_to_words(
_ID_LENGTH_BITS + _EXTENDABLE_BACKUP_FLAG_LENGTH_BITS + _ITERATION_EXP_LENGTH_BITS
)
"""The length of the random identifier, extendable backup flag and iteration exponent in words."""
_INDEX_LENGTH_BITS = 4
"""The length of the group index, group threshold, group count, and member index in bits."""
_CHECKSUM_LENGTH_WORDS = 3
"""The length of the RS1024 checksum in words."""
@@ -75,8 +83,11 @@ _CHECKSUM_LENGTH_WORDS = 3
_DIGEST_LENGTH_BYTES = 4
"""The length of the digest of the shared secret in bytes."""
_CUSTOMIZATION_STRING = b"shamir"
"""The customization string used in the RS1024 checksum and in the PBKDF2 salt."""
_CUSTOMIZATION_STRING_NON_EXTENDABLE = b"shamir"
"""The customization string used in the RS1024 checksum and in the PBKDF2 salt when extendable backup flag is not set."""
_CUSTOMIZATION_STRING_EXTENDABLE = b"shamir_extendable"
"""The customization string used in the RS1024 checksum when extendable backup flag is set."""
_GROUP_PREFIX_LENGTH_WORDS = _ID_EXP_LENGTH_WORDS + 1
"""The length of the prefix of the mnemonic that is common to a share group."""
@@ -120,6 +131,7 @@ class Share:
def __init__(
self,
identifier: int,
extendable_backup_flag: bool,
iteration_exponent: int,
group_index: int,
group_threshold: int,
@@ -130,6 +142,7 @@ class Share:
):
self.index = None
self.identifier = identifier
self.extendable_backup_flag = extendable_backup_flag
self.iteration_exponent = iteration_exponent
self.group_index = group_index
self.group_threshold = group_threshold
@@ -142,6 +155,7 @@ class Share:
"""Return the values that uniquely identify a matching set of shares."""
return (
self.identifier,
self.extendable_backup_flag,
self.iteration_exponent,
self.group_threshold,
self.group_count,
@@ -153,8 +167,15 @@ class EncryptedSeed:
Represents the encrypted master seed for BIP-32.
"""
def __init__(self, identifier: int, iteration_exponent: int, encrypted_master_secret: bytes):
def __init__(
self,
identifier: int,
extendable_backup_flag: bool,
iteration_exponent: int,
encrypted_master_secret: bytes,
):
self.identifier = identifier
self.extendable_backup_flag = extendable_backup_flag
self.iteration_exponent = iteration_exponent
self.encrypted_master_secret = encrypted_master_secret
@@ -169,7 +190,7 @@ class EncryptedSeed:
ems_len = len(self.encrypted_master_secret)
l = self.encrypted_master_secret[: ems_len // 2]
r = self.encrypted_master_secret[ems_len // 2 :]
salt = _get_salt(self.identifier)
salt = _get_salt(self.identifier, self.extendable_backup_flag)
for i in reversed(range(_ROUND_COUNT)):
(l, r) = (
r,
@@ -190,6 +211,7 @@ def recover_ems(mnemonics: List[str]) -> EncryptedSeed:
(
identifier,
extendable_backup_flag,
iteration_exponent,
group_threshold,
group_count,
@@ -212,7 +234,9 @@ def recover_ems(mnemonics: List[str]) -> EncryptedSeed:
]
encrypted_master_secret = _recover_secret(group_threshold, group_shares)
return EncryptedSeed(identifier, iteration_exponent, encrypted_master_secret)
return EncryptedSeed(
identifier, extendable_backup_flag, iteration_exponent, encrypted_master_secret
)
def decode_mnemonic(mnemonic: str) -> Share:
@@ -227,12 +251,19 @@ def decode_mnemonic(mnemonic: str) -> Share:
if padding_len > 8:
raise Slip39Error(_('Invalid length.'))
if not _rs1024_verify_checksum(mnemonic_data):
idExpExtInt = _int_from_indices(mnemonic_data[:_ID_EXP_LENGTH_WORDS])
identifier = idExpExtInt >> (
_EXTENDABLE_BACKUP_FLAG_LENGTH_BITS + _ITERATION_EXP_LENGTH_BITS
)
extendable_backup_flag = bool(
(idExpExtInt >> _ITERATION_EXP_LENGTH_BITS)
& ((1 << _EXTENDABLE_BACKUP_FLAG_LENGTH_BITS) - 1)
)
iteration_exponent = idExpExtInt & ((1 << _ITERATION_EXP_LENGTH_BITS) - 1)
if not _rs1024_verify_checksum(mnemonic_data, extendable_backup_flag):
raise Slip39Error(_('Invalid mnemonic checksum.'))
id_exp_int = _int_from_indices(mnemonic_data[:_ID_EXP_LENGTH_WORDS])
identifier = id_exp_int >> _ITERATION_EXP_LENGTH_BITS
iteration_exponent = id_exp_int & ((1 << _ITERATION_EXP_LENGTH_BITS) - 1)
tmp = _int_from_indices(
mnemonic_data[_ID_EXP_LENGTH_WORDS : _ID_EXP_LENGTH_WORDS + 2]
)
@@ -242,7 +273,7 @@ def decode_mnemonic(mnemonic: str) -> Share:
group_count,
member_index,
member_threshold,
) = _int_to_indices(tmp, 5, 4)
) = _int_to_indices(tmp, 5, _INDEX_LENGTH_BITS)
value_data = mnemonic_data[_ID_EXP_LENGTH_WORDS + 2 : -_CHECKSUM_LENGTH_WORDS]
if group_count < group_threshold:
@@ -256,6 +287,7 @@ def decode_mnemonic(mnemonic: str) -> Share:
return Share(
identifier,
extendable_backup_flag,
iteration_exponent,
group_index,
group_threshold + 1,
@@ -314,6 +346,7 @@ def process_mnemonics(mnemonics: List[str]) -> Tuple[Optional[EncryptedSeed], st
groups_completed += 1
identifier = shares[0].identifier
extendable_backup_flag = shares[0].extendable_backup_flag
iteration_exponent = shares[0].iteration_exponent
group_threshold = shares[0].group_threshold
group_count = shares[0].group_count
@@ -323,7 +356,14 @@ def process_mnemonics(mnemonics: List[str]) -> Tuple[Optional[EncryptedSeed], st
status += ":<br/>"
for group_index in range(group_count):
group_prefix = _make_group_prefix(identifier, iteration_exponent, group_index, group_threshold, group_count)
group_prefix = _make_group_prefix(
identifier,
extendable_backup_flag,
iteration_exponent,
group_index,
group_threshold,
group_count,
)
status += _group_status(groups[group_index], group_prefix)
if groups_completed >= group_threshold:
@@ -350,16 +390,25 @@ _EMPTY = '<span style="color:red;">&#x2715;</span>'
_INPROGRESS = '<span style="color:orange;">&#x26ab;</span>'
_ERROR_STYLE = '<span style="color:red; font-weight:bold;">' + _('Error') + ': %s</span>'
def _make_group_prefix(identifier, iteration_exponent, group_index, group_threshold, group_count):
def _make_group_prefix(
identifier,
extendable_backup_flag,
iteration_exponent,
group_index,
group_threshold,
group_count,
):
wordlist = get_wordlist()
val = identifier
val <<= _EXTENDABLE_BACKUP_FLAG_LENGTH_BITS
val += int(extendable_backup_flag)
val <<= _ITERATION_EXP_LENGTH_BITS
val += iteration_exponent
val <<= 4
val <<= _INDEX_LENGTH_BITS
val += group_index
val <<= 4
val <<= _INDEX_LENGTH_BITS
val += group_threshold - 1
val <<= 4
val <<= _INDEX_LENGTH_BITS
val += group_count - 1
val >>= 2
prefix = ' '.join(wordlist[idx] for idx in _int_to_indices(val, _GROUP_PREFIX_LENGTH_WORDS, _RADIX_BITS))
@@ -413,6 +462,13 @@ def _mnemonic_to_indices(mnemonic: str) -> List[int]:
"""
def _get_customization_string(extendable_backup_flag: bool) -> bytes:
if extendable_backup_flag:
return _CUSTOMIZATION_STRING_EXTENDABLE
else:
return _CUSTOMIZATION_STRING_NON_EXTENDABLE
def _rs1024_polymod(values: Indices) -> int:
GEN = (
0xE0E040,
@@ -435,11 +491,14 @@ def _rs1024_polymod(values: Indices) -> int:
return chk
def _rs1024_verify_checksum(data: Indices) -> bool:
def _rs1024_verify_checksum(data: Indices, extendable_backup_flag: bool) -> bool:
"""
Verifies a checksum of the given mnemonic, which was already parsed into Indices.
"""
return _rs1024_polymod(tuple(_CUSTOMIZATION_STRING) + data) == 1
return (
_rs1024_polymod(tuple(_get_customization_string(extendable_backup_flag)) + data)
== 1
)
"""
@@ -532,10 +591,13 @@ def _round_function(i: int, passphrase: bytes, e: int, salt: bytes, r: bytes) ->
)
def _get_salt(identifier: int) -> bytes:
return _CUSTOMIZATION_STRING + identifier.to_bytes(
_bits_to_bytes(_ID_LENGTH_BITS), "big"
)
def _get_salt(identifier: int, extendable_backup_flag: bool) -> bytes:
if extendable_backup_flag:
return bytes()
else:
return _CUSTOMIZATION_STRING_NON_EXTENDABLE + identifier.to_bytes(
_bits_to_bytes(_ID_LENGTH_BITS), "big"
)
def _create_digest(random_data: bytes, shared_secret: bytes) -> bytes:
@@ -562,6 +624,7 @@ def _decode_mnemonics(
mnemonics: List[str],
) -> Tuple[int, int, int, int, MnemonicGroups]:
identifiers = set()
extendable_backup_flags = set()
iteration_exponents = set()
group_thresholds = set()
group_counts = set()
@@ -571,6 +634,7 @@ def _decode_mnemonics(
for mnemonic in mnemonics:
share = decode_mnemonic(mnemonic)
identifiers.add(share.identifier)
extendable_backup_flags.add(share.extendable_backup_flag)
iteration_exponents.add(share.iteration_exponent)
group_thresholds.add(share.group_threshold)
group_counts.add(share.group_count)
@@ -581,7 +645,11 @@ def _decode_mnemonics(
)
group[1].add((share.member_index, share.share_value))
if len(identifiers) != 1 or len(iteration_exponents) != 1:
if (
len(identifiers) != 1
or len(extendable_backup_flags) != 1
or len(iteration_exponents) != 1
):
raise Slip39Error(
"Invalid set of mnemonics. All mnemonics must begin with the same {} words.".format(
_ID_EXP_LENGTH_WORDS
@@ -606,6 +674,7 @@ def _decode_mnemonics(
return (
identifiers.pop(),
extendable_backup_flags.pop(),
iteration_exponents.pop(),
group_thresholds.pop(),
group_counts.pop(),