1
0

Restructure wallet storage:

- Perform json deserializations in wallet_db
 - use StoredDict class that keeps tracks of its modifications
This commit is contained in:
ThomasV
2020-02-04 13:35:58 +01:00
parent 0a9e7cb04e
commit dbceed2647
14 changed files with 303 additions and 291 deletions

View File

@@ -29,12 +29,16 @@ import copy
import threading
from collections import defaultdict
from typing import Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence
import binascii
from . import util, bitcoin
from .util import profiler, WalletFileException, multisig_type, TxMinedInfo, bfh
from .util import profiler, WalletFileException, multisig_type, TxMinedInfo, bfh, PR_TYPE_ONCHAIN
from .keystore import bip44_derivation
from .transaction import Transaction, TxOutpoint, tx_from_any, PartialTransaction
from .json_db import JsonDB, locked, modifier
from .transaction import Transaction, TxOutpoint, tx_from_any, PartialTransaction, PartialTxOutput
from .logging import Logger
from .lnutil import LOCAL, REMOTE, FeeUpdate, UpdateAddHtlc, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKeypair, RevocationStore
from .lnutil import ChannelConstraints, Outpoint, ShachainElement
from .json_db import StoredDict, JsonDB, locked, modifier
# seed_version is now used for the version of the wallet file
@@ -44,17 +48,12 @@ FINAL_SEED_VERSION = 24 # electrum >= 2.7 will set this to prevent
# old versions from overwriting new format
class TxFeesValue(NamedTuple):
fee: Optional[int] = None
is_calculated_by_us: bool = False
num_inputs: Optional[int] = None
class WalletDB(JsonDB):
def __init__(self, raw, *, manual_upgrades: bool):
@@ -67,7 +66,6 @@ class WalletDB(JsonDB):
self.put('seed_version', FINAL_SEED_VERSION)
self._after_upgrade_tasks()
def load_data(self, s):
try:
self.data = json.loads(s)
@@ -833,7 +831,7 @@ class WalletDB(JsonDB):
self.tx_fees.pop(txid, None)
@locked
def get_data_ref(self, name):
def get_dict(self, name):
# Warning: interacts un-intuitively with 'put': certain parts
# of 'data' will have pointers saved as separate variables.
if name not in self.data:
@@ -895,9 +893,9 @@ class WalletDB(JsonDB):
def load_addresses(self, wallet_type):
""" called from Abstract_Wallet.__init__ """
if wallet_type == 'imported':
self.imported_addresses = self.get_data_ref('addresses') # type: Dict[str, dict]
self.imported_addresses = self.get_dict('addresses') # type: Dict[str, dict]
else:
self.get_data_ref('addresses')
self.get_dict('addresses')
for name in ['receiving', 'change']:
if name not in self.data['addresses']:
self.data['addresses'][name] = []
@@ -911,26 +909,20 @@ class WalletDB(JsonDB):
@profiler
def _load_transactions(self):
self.data = StoredDict(self.data, self, [])
# references in self.data
# TODO make all these private
# txid -> address -> set of (prev_outpoint, value)
self.txi = self.get_data_ref('txi') # type: Dict[str, Dict[str, Set[Tuple[str, int]]]]
self.txi = self.get_dict('txi') # type: Dict[str, Dict[str, Set[Tuple[str, int]]]]
# txid -> address -> set of (output_index, value, is_coinbase)
self.txo = self.get_data_ref('txo') # type: Dict[str, Dict[str, Set[Tuple[int, int, bool]]]]
self.transactions = self.get_data_ref('transactions') # type: Dict[str, Transaction]
self.spent_outpoints = self.get_data_ref('spent_outpoints') # txid -> output_index -> next_txid
self.history = self.get_data_ref('addr_history') # address -> list of (txid, height)
self.verified_tx = self.get_data_ref('verified_tx3') # txid -> (height, timestamp, txpos, header_hash)
self.tx_fees = self.get_data_ref('tx_fees') # type: Dict[str, TxFeesValue]
self.txo = self.get_dict('txo') # type: Dict[str, Dict[str, Set[Tuple[int, int, bool]]]]
self.transactions = self.get_dict('transactions') # type: Dict[str, Transaction]
self.spent_outpoints = self.get_dict('spent_outpoints') # txid -> output_index -> next_txid
self.history = self.get_dict('addr_history') # address -> list of (txid, height)
self.verified_tx = self.get_dict('verified_tx3') # txid -> (height, timestamp, txpos, header_hash)
self.tx_fees = self.get_dict('tx_fees') # type: Dict[str, TxFeesValue]
# scripthash -> set of (outpoint, value)
self._prevouts_by_scripthash = self.get_data_ref('prevouts_by_scripthash') # type: Dict[str, Set[Tuple[str, int]]]
# convert raw transactions to Transaction objects
for tx_hash, raw_tx in self.transactions.items():
# note: for performance, "deserialize=False" so that we will deserialize these on-demand
self.transactions[tx_hash] = tx_from_any(raw_tx, deserialize=False)
# convert prevouts_by_scripthash: list to set, list to tuple
for scripthash, lst in self._prevouts_by_scripthash.items():
self._prevouts_by_scripthash[scripthash] = {(prevout, value) for prevout, value in lst}
self._prevouts_by_scripthash = self.get_dict('prevouts_by_scripthash') # type: Dict[str, Set[Tuple[str, int]]]
# remove unreferenced tx
for tx_hash in list(self.transactions.keys()):
if not self.get_txi_addresses(tx_hash) and not self.get_txo_addresses(tx_hash):
@@ -943,9 +935,15 @@ class WalletDB(JsonDB):
if spending_txid not in self.transactions:
self.logger.info("removing unreferenced spent outpoint")
d.pop(prevout_n)
# convert tx_fees tuples to NamedTuples
for tx_hash, tuple_ in self.tx_fees.items():
self.tx_fees[tx_hash] = TxFeesValue(*tuple_)
# convert invoices
# TODO invoices being these contextual dicts even internally,
# where certain keys are only present depending on values of other keys...
# it's horrible. we need to change this, at least for the internal representation,
# to something that can be typed.
self.invoices = self.get_dict('invoices')
for invoice_key, invoice in self.invoices.items():
if invoice.get('type') == PR_TYPE_ONCHAIN:
invoice['outputs'] = [PartialTxOutput.from_legacy_tuple(*output) for output in invoice.get('outputs')]
@modifier
def clear_history(self):
@@ -956,3 +954,42 @@ class WalletDB(JsonDB):
self.history.clear()
self.verified_tx.clear()
self.tx_fees.clear()
def _convert_dict(self, path, key, v):
if key == 'transactions':
# note: for performance, "deserialize=False" so that we will deserialize these on-demand
v = dict((k, tx_from_any(x, deserialize=False)) for k, x in v.items())
elif key == 'adds':
v = dict((k, UpdateAddHtlc(*x)) for k, x in v.items())
elif key == 'fee_updates':
v = dict((k, FeeUpdate(**x)) for k, x in v.items())
elif key == 'tx_fees':
v = dict((k, TxFeesValue(*x)) for k, x in v.items())
elif key == 'prevouts_by_scripthash':
v = dict((k, {(prevout, value) for (prevout, value) in x}) for k, x in v.items())
elif key == 'buckets':
v = dict((k, ShachainElement(bfh(x[0]), int(x[1]))) for k, x in v.items())
return v
def _convert_value(self, path, key, v):
if key == 'local_config':
v = LocalConfig(**v)
elif key == 'remote_config':
v = RemoteConfig(**v)
elif key == 'constraints':
v = ChannelConstraints(**v)
elif key == 'funding_outpoint':
v = Outpoint(**v)
elif key.endswith("_basepoint") or key.endswith("_key"):
v = Keypair(**v) if len(v)==2 else OnlyPubkeyKeypair(**v)
elif key in [
"short_channel_id",
"current_per_commitment_point",
"next_per_commitment_point",
"per_commitment_secret_seed",
"current_commitment_signature",
"current_htlc_signatures"]:
v = binascii.unhexlify(v) if v is not None else None
elif len(path) > 2 and path[-2] in ['local_config', 'remote_config'] and key in ["pubkey", "privkey"]:
v = binascii.unhexlify(v) if v is not None else None
return v