storage: encapsulate type conversions of stored objects using
decorators (instead of overloading JsonDB._convert_dict and _convert_value) - stored_in for elements of a StoreDict - stored_as for singletons - extra register methods are defined for key conversions This commit was adapted from the jsonpatch branch
This commit is contained in:
@@ -4,7 +4,7 @@ from decimal import Decimal
|
||||
|
||||
import attr
|
||||
|
||||
from .json_db import StoredObject
|
||||
from .json_db import StoredObject, stored_in
|
||||
from .i18n import _
|
||||
from .util import age, InvoiceError, format_satoshis
|
||||
from .lnutil import hex_to_bytes
|
||||
@@ -244,6 +244,7 @@ class BaseInvoice(StoredObject):
|
||||
return d
|
||||
|
||||
|
||||
@stored_in('invoices')
|
||||
@attr.s
|
||||
class Invoice(BaseInvoice):
|
||||
lightning_invoice = attr.ib(type=str, kw_only=True) # type: Optional[str]
|
||||
@@ -303,6 +304,7 @@ class Invoice(BaseInvoice):
|
||||
return d
|
||||
|
||||
|
||||
@stored_in('payment_requests')
|
||||
@attr.s
|
||||
class Request(BaseInvoice):
|
||||
payment_hash = attr.ib(type=bytes, kw_only=True, converter=hex_to_bytes) # type: Optional[bytes]
|
||||
|
||||
@@ -45,6 +45,28 @@ def locked(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
registered_names = {}
|
||||
registered_dicts = {}
|
||||
registered_dict_keys = {}
|
||||
registered_parent_keys = {}
|
||||
|
||||
|
||||
def stored_as(name, _type=dict):
|
||||
""" decorator that indicates the storage key of a stored object"""
|
||||
def decorator(func):
|
||||
registered_names[name] = func, _type
|
||||
return func
|
||||
return decorator
|
||||
|
||||
def stored_in(name, _type=dict):
|
||||
""" decorator that indicates the storage key of an element in a StoredDict"""
|
||||
def decorator(func):
|
||||
registered_dicts[name] = func, _type
|
||||
return func
|
||||
return decorator
|
||||
|
||||
|
||||
|
||||
class StoredObject:
|
||||
|
||||
db = None
|
||||
@@ -195,3 +217,46 @@ class JsonDB(Logger):
|
||||
|
||||
def _should_convert_to_stored_dict(self, key) -> bool:
|
||||
return True
|
||||
|
||||
def register_dict(self, name, method, _type):
|
||||
registered_dicts[name] = method, _type
|
||||
|
||||
def register_name(self, name, method, _type):
|
||||
registered_names[name] = method, _type
|
||||
|
||||
def register_dict_key(self, name, method):
|
||||
registered_dict_keys[name] = method
|
||||
|
||||
def register_parent_key(self, name, method):
|
||||
registered_parent_keys[name] = method
|
||||
|
||||
def _convert_dict(self, path, key, v):
|
||||
|
||||
if key in registered_dicts:
|
||||
constructor, _type = registered_dicts[key]
|
||||
if _type == dict:
|
||||
v = dict((k, constructor(**x)) for k, x in v.items())
|
||||
elif _type == tuple:
|
||||
v = dict((k, constructor(*x)) for k, x in v.items())
|
||||
else:
|
||||
v = dict((k, constructor(x)) for k, x in v.items())
|
||||
|
||||
if key in registered_dict_keys:
|
||||
convert_key = registered_dict_keys[key]
|
||||
elif path and path[-1] in registered_parent_keys:
|
||||
convert_key = registered_parent_keys.get(path[-1])
|
||||
else:
|
||||
convert_key = None
|
||||
if convert_key:
|
||||
v = dict((convert_key(k), x) for k, x in v.items())
|
||||
|
||||
return v
|
||||
|
||||
def _convert_value(self, path, key, v):
|
||||
if key in registered_names:
|
||||
constructor, _type = registered_names[key]
|
||||
if _type == dict:
|
||||
v = constructor(**v)
|
||||
else:
|
||||
v = constructor(v)
|
||||
return v
|
||||
|
||||
@@ -52,7 +52,7 @@ DUST_LIMIT_MAX = 1000
|
||||
def ln_dummy_address():
|
||||
return redeem_script_to_address('p2wsh', '')
|
||||
|
||||
from .json_db import StoredObject
|
||||
from .json_db import StoredObject, stored_in, stored_as
|
||||
|
||||
|
||||
def channel_id_from_funding_tx(funding_txid: str, funding_index: int) -> Tuple[bytes, bytes]:
|
||||
@@ -181,6 +181,7 @@ class ChannelConfig(StoredObject):
|
||||
raise Exception(f"feerate lower than min relay fee. {initial_feerate_per_kw} sat/kw.")
|
||||
|
||||
|
||||
@stored_as('local_config')
|
||||
@attr.s
|
||||
class LocalConfig(ChannelConfig):
|
||||
channel_seed = attr.ib(type=bytes, converter=hex_to_bytes) # type: Optional[bytes]
|
||||
@@ -214,17 +215,20 @@ class LocalConfig(ChannelConfig):
|
||||
if self.htlc_minimum_msat < HTLC_MINIMUM_MSAT_MIN:
|
||||
raise Exception(f"{conf_name}. htlc_minimum_msat too low: {self.htlc_minimum_msat} msat < {HTLC_MINIMUM_MSAT_MIN}")
|
||||
|
||||
@stored_as('remote_config')
|
||||
@attr.s
|
||||
class RemoteConfig(ChannelConfig):
|
||||
next_per_commitment_point = attr.ib(type=bytes, converter=hex_to_bytes)
|
||||
current_per_commitment_point = attr.ib(default=None, type=bytes, converter=hex_to_bytes)
|
||||
|
||||
@stored_in('fee_updates')
|
||||
@attr.s
|
||||
class FeeUpdate(StoredObject):
|
||||
rate = attr.ib(type=int) # in sat/kw
|
||||
ctn_local = attr.ib(default=None, type=int)
|
||||
ctn_remote = attr.ib(default=None, type=int)
|
||||
|
||||
@stored_as('constraints')
|
||||
@attr.s
|
||||
class ChannelConstraints(StoredObject):
|
||||
capacity = attr.ib(type=int) # in sat
|
||||
@@ -248,10 +252,12 @@ class ChannelBackupStorage(StoredObject):
|
||||
chan_id, _ = channel_id_from_funding_tx(self.funding_txid, self.funding_index)
|
||||
return chan_id
|
||||
|
||||
@stored_in('onchain_channel_backups')
|
||||
@attr.s
|
||||
class OnchainChannelBackupStorage(ChannelBackupStorage):
|
||||
node_id_prefix = attr.ib(type=bytes, converter=hex_to_bytes)
|
||||
|
||||
@stored_in('imported_channel_backups')
|
||||
@attr.s
|
||||
class ImportedChannelBackupStorage(ChannelBackupStorage):
|
||||
node_id = attr.ib(type=bytes, converter=hex_to_bytes)
|
||||
@@ -320,6 +326,7 @@ class ScriptHtlc(NamedTuple):
|
||||
|
||||
|
||||
# FIXME duplicate of TxOutpoint in transaction.py??
|
||||
@stored_as('funding_outpoint')
|
||||
@attr.s
|
||||
class Outpoint(StoredObject):
|
||||
txid = attr.ib(type=str)
|
||||
@@ -484,8 +491,17 @@ def shachain_derive(element, to_index):
|
||||
get_per_commitment_secret_from_seed(element.secret, to_index, zeros),
|
||||
to_index)
|
||||
|
||||
ShachainElement = namedtuple("ShachainElement", ["secret", "index"])
|
||||
ShachainElement.__str__ = lambda self: f"ShachainElement({self.secret.hex()},{self.index})"
|
||||
class ShachainElement(NamedTuple):
|
||||
secret: bytes
|
||||
index: int
|
||||
|
||||
def __str__(self):
|
||||
return "ShachainElement(" + self.secret.hex() + "," + str(self.index) + ")"
|
||||
|
||||
@stored_in('buckets', tuple)
|
||||
def read(*x):
|
||||
return ShachainElement(bfh(x[0]), int(x[1]))
|
||||
|
||||
|
||||
def get_per_commitment_secret_from_seed(seed: bytes, i: int, bits: int = 48) -> bytes:
|
||||
"""Generate per commitment secret."""
|
||||
@@ -1226,6 +1242,7 @@ class LnFeatures(IntFlag):
|
||||
return hex(self._value_)
|
||||
|
||||
|
||||
@stored_as('channel_type', _type=None)
|
||||
class ChannelType(IntFlag):
|
||||
OPTION_LEGACY_CHANNEL = 0
|
||||
OPTION_STATIC_REMOTEKEY = 1 << 12
|
||||
@@ -1546,15 +1563,16 @@ class UpdateAddHtlc:
|
||||
timestamp = attr.ib(type=int, kw_only=True)
|
||||
htlc_id = attr.ib(type=int, kw_only=True, default=None)
|
||||
|
||||
@classmethod
|
||||
def from_tuple(cls, amount_msat, payment_hash, cltv_expiry, htlc_id, timestamp) -> 'UpdateAddHtlc':
|
||||
return cls(amount_msat=amount_msat,
|
||||
payment_hash=payment_hash,
|
||||
cltv_expiry=cltv_expiry,
|
||||
htlc_id=htlc_id,
|
||||
timestamp=timestamp)
|
||||
@stored_in('adds', tuple)
|
||||
def from_tuple(amount_msat, payment_hash, cltv_expiry, htlc_id, timestamp) -> 'UpdateAddHtlc':
|
||||
return UpdateAddHtlc(
|
||||
amount_msat=amount_msat,
|
||||
payment_hash=payment_hash,
|
||||
cltv_expiry=cltv_expiry,
|
||||
htlc_id=htlc_id,
|
||||
timestamp=timestamp)
|
||||
|
||||
def to_tuple(self):
|
||||
def to_json(self):
|
||||
return (self.amount_msat, self.payment_hash, self.cltv_expiry, self.htlc_id, self.timestamp)
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from .lnutil import REDEEM_AFTER_DOUBLE_SPENT_DELAY, ln_dummy_address
|
||||
from .bitcoin import dust_threshold
|
||||
from .logging import Logger
|
||||
from .lnutil import hex_to_bytes
|
||||
from .json_db import StoredObject
|
||||
from .json_db import StoredObject, stored_in
|
||||
from . import constants
|
||||
from .address_synchronizer import TX_HEIGHT_LOCAL
|
||||
from .i18n import _
|
||||
@@ -87,6 +87,7 @@ class SwapServerError(Exception):
|
||||
return _("The swap server errored or is unreachable.")
|
||||
|
||||
|
||||
@stored_in('submarine_swaps')
|
||||
@attr.s
|
||||
class SwapData(StoredObject):
|
||||
is_reverse = attr.ib(type=bool)
|
||||
|
||||
@@ -53,6 +53,7 @@ from .crypto import sha256d
|
||||
from .logging import get_logger
|
||||
from .util import ShortID, OldTaskGroup
|
||||
from .descriptor import Descriptor, MissingSolutionPiece, create_dummy_descriptor_from_address
|
||||
from .json_db import stored_in
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .wallet import Abstract_Wallet
|
||||
|
||||
@@ -297,9 +297,6 @@ class MyEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
# note: this does not get called for namedtuples :( https://bugs.python.org/issue30343
|
||||
from .transaction import Transaction, TxOutput
|
||||
from .lnutil import UpdateAddHtlc
|
||||
if isinstance(obj, UpdateAddHtlc):
|
||||
return obj.to_tuple()
|
||||
if isinstance(obj, Transaction):
|
||||
return obj.serialize()
|
||||
if isinstance(obj, TxOutput):
|
||||
|
||||
@@ -41,12 +41,10 @@ from .invoices import Invoice, Request
|
||||
from .keystore import bip44_derivation
|
||||
from .transaction import Transaction, TxOutpoint, tx_from_any, PartialTransaction, PartialTxOutput
|
||||
from .logging import Logger
|
||||
from .lnutil import LOCAL, REMOTE, FeeUpdate, UpdateAddHtlc, LocalConfig, RemoteConfig, ChannelType
|
||||
from .lnutil import ImportedChannelBackupStorage, OnchainChannelBackupStorage
|
||||
from .lnutil import ChannelConstraints, Outpoint, ShachainElement
|
||||
from .json_db import StoredDict, JsonDB, locked, modifier, StoredObject
|
||||
|
||||
from .lnutil import LOCAL, REMOTE, HTLCOwner, ChannelType
|
||||
from .json_db import StoredDict, JsonDB, locked, modifier, StoredObject, stored_in, stored_as
|
||||
from .plugin import run_hook, plugin_loaders
|
||||
from .submarine_swaps import SwapData
|
||||
from .version import ELECTRUM_VERSION
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -61,12 +59,14 @@ FINAL_SEED_VERSION = 52 # electrum >= 2.7 will set this to prevent
|
||||
# old versions from overwriting new format
|
||||
|
||||
|
||||
@stored_in('tx_fees', tuple)
|
||||
class TxFeesValue(NamedTuple):
|
||||
fee: Optional[int] = None
|
||||
is_calculated_by_us: bool = False
|
||||
num_inputs: Optional[int] = None
|
||||
|
||||
|
||||
@stored_as('db_metadata')
|
||||
@attr.s
|
||||
class DBMetadata(StoredObject):
|
||||
creation_timestamp = attr.ib(default=None, type=int)
|
||||
@@ -91,6 +91,20 @@ class WalletDB(JsonDB):
|
||||
|
||||
def __init__(self, raw, *, manual_upgrades: bool):
|
||||
JsonDB.__init__(self, {})
|
||||
# register dicts that require value conversions not handled by constructor
|
||||
self.register_dict('transactions', lambda x: tx_from_any(x, deserialize=False), None)
|
||||
self.register_dict('prevouts_by_scripthash', lambda x: set(tuple(k) for k in x), None)
|
||||
self.register_dict('data_loss_protect_remote_pcp', lambda x: bytes.fromhex(x), None)
|
||||
# register dicts that require key conversion
|
||||
for key in [
|
||||
'adds', 'locked_in', 'settles', 'fails', 'fee_updates', 'buckets',
|
||||
'unacked_updates', 'unfulfilled_htlcs', 'fail_htlc_reasons', 'onion_keys']:
|
||||
self.register_dict_key(key, int)
|
||||
for key in ['log']:
|
||||
self.register_dict_key(key, lambda x: HTLCOwner(int(x)))
|
||||
for key in ['locked_in', 'fails', 'settles']:
|
||||
self.register_parent_key(key, lambda x: HTLCOwner(int(x)))
|
||||
|
||||
self._manual_upgrades = manual_upgrades
|
||||
self._called_after_upgrade_tasks = False
|
||||
if raw: # loading existing db
|
||||
@@ -1560,58 +1574,6 @@ class WalletDB(JsonDB):
|
||||
self.tx_fees.clear()
|
||||
self._prevouts_by_scripthash.clear()
|
||||
|
||||
def _convert_dict(self, path, key, v):
|
||||
if key == 'transactions':
|
||||
# note: for performance, "deserialize=False" so that we will deserialize these on-demand
|
||||
v = dict((k, tx_from_any(x, deserialize=False)) for k, x in v.items())
|
||||
if key == 'invoices':
|
||||
v = dict((k, Invoice(**x)) for k, x in v.items())
|
||||
if key == 'payment_requests':
|
||||
v = dict((k, Request(**x)) for k, x in v.items())
|
||||
elif key == 'adds':
|
||||
v = dict((k, UpdateAddHtlc.from_tuple(*x)) for k, x in v.items())
|
||||
elif key == 'fee_updates':
|
||||
v = dict((k, FeeUpdate(**x)) for k, x in v.items())
|
||||
elif key == 'submarine_swaps':
|
||||
v = dict((k, SwapData(**x)) for k, x in v.items())
|
||||
elif key == 'imported_channel_backups':
|
||||
v = dict((k, ImportedChannelBackupStorage(**x)) for k, x in v.items())
|
||||
elif key == 'onchain_channel_backups':
|
||||
v = dict((k, OnchainChannelBackupStorage(**x)) for k, x in v.items())
|
||||
elif key == 'tx_fees':
|
||||
v = dict((k, TxFeesValue(*x)) for k, x in v.items())
|
||||
elif key == 'prevouts_by_scripthash':
|
||||
v = dict((k, {(prevout, value) for (prevout, value) in x}) for k, x in v.items())
|
||||
elif key == 'buckets':
|
||||
v = dict((k, ShachainElement(bfh(x[0]), int(x[1]))) for k, x in v.items())
|
||||
elif key == 'data_loss_protect_remote_pcp':
|
||||
v = dict((k, bfh(x)) for k, x in v.items())
|
||||
# convert htlc_id keys to int
|
||||
if key in ['adds', 'locked_in', 'settles', 'fails', 'fee_updates', 'buckets',
|
||||
'unacked_updates', 'unfulfilled_htlcs', 'fail_htlc_reasons', 'onion_keys']:
|
||||
v = dict((int(k), x) for k, x in v.items())
|
||||
# convert keys to HTLCOwner
|
||||
if key == 'log' or (path and path[-1] in ['locked_in', 'fails', 'settles']):
|
||||
if "1" in v:
|
||||
v[LOCAL] = v.pop("1")
|
||||
v[REMOTE] = v.pop("-1")
|
||||
return v
|
||||
|
||||
def _convert_value(self, path, key, v):
|
||||
if key == 'local_config':
|
||||
v = LocalConfig(**v)
|
||||
elif key == 'remote_config':
|
||||
v = RemoteConfig(**v)
|
||||
elif key == 'constraints':
|
||||
v = ChannelConstraints(**v)
|
||||
elif key == 'funding_outpoint':
|
||||
v = Outpoint(**v)
|
||||
elif key == 'channel_type':
|
||||
v = ChannelType(v)
|
||||
elif key == 'db_metadata':
|
||||
v = DBMetadata(**v)
|
||||
return v
|
||||
|
||||
def _should_convert_to_stored_dict(self, key) -> bool:
|
||||
if key == 'keystore':
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user