1
0

storage_db: fix tests, add modified flag to db class

This commit is contained in:
ThomasV
2019-02-28 11:55:15 +01:00
parent dbca0a0e83
commit d74f0c0947
5 changed files with 60 additions and 37 deletions

View File

@@ -26,6 +26,7 @@ import os
import ast
import json
import copy
import threading
from collections import defaultdict
from typing import Dict
@@ -45,7 +46,9 @@ FINAL_SEED_VERSION = 18 # electrum >= 2.7 will set this to prevent
class JsonDB(PrintError):
def __init__(self, raw, *, manual_upgrades):
self.lock = threading.RLock()
self.data = {}
self._modified = False
self.manual_upgrades = manual_upgrades
if raw:
self.load_data(raw)
@@ -53,6 +56,20 @@ class JsonDB(PrintError):
self.put('seed_version', FINAL_SEED_VERSION)
self.load_transactions()
def set_modified(self, b):
with self.lock:
self._modified = b
def modified(self):
return self._modified
def modifier(func):
def wrapper(self, *args, **kwargs):
with self.lock:
self._modified = True
return func(self, *args, **kwargs)
return wrapper
def get(self, key, default=None):
v = self.data.get(key)
if v is None:
@@ -61,6 +78,7 @@ class JsonDB(PrintError):
v = copy.deepcopy(v)
return v
@modifier
def put(self, key, value):
try:
json.dumps(key, cls=util.MyEncoder)
@@ -483,6 +501,7 @@ class JsonDB(PrintError):
def get_txo_addr(self, tx_hash, address):
return self.txo.get(tx_hash, {}).get(address, [])
@modifier
def add_txi_addr(self, tx_hash, addr, ser, v):
if tx_hash not in self.txi:
self.txi[tx_hash] = {}
@@ -492,6 +511,7 @@ class JsonDB(PrintError):
d[addr] = set()
d[addr].add((ser, v))
@modifier
def add_txo_addr(self, tx_hash, addr, n, v, is_coinbase):
if tx_hash not in self.txo:
self.txo[tx_hash] = {}
@@ -507,26 +527,43 @@ class JsonDB(PrintError):
def get_txo_keys(self):
return self.txo.keys()
@modifier
def remove_txi(self, tx_hash):
self.txi.pop(tx_hash, None)
@modifier
def remove_txo(self, tx_hash):
self.txo.pop(tx_hash, None)
def list_spent_outpoints(self):
return [(h, n)
for h in self.spent_outpoints.keys()
for n in self.get_spent_outpoints(h)
]
def get_spent_outpoints(self, prevout_hash):
return self.spent_outpoints.get(prevout_hash, {}).keys()
def get_spent_outpoint(self, prevout_hash, prevout_n):
return self.spent_outpoints.get(prevout_hash, {}).get(str(prevout_n))
@modifier
def remove_spent_outpoint(self, prevout_hash, prevout_n):
self.spent_outpoints[prevout_hash].pop(prevout_n, None) # FIXME
if not self.spent_outpoints[prevout_hash]:
self.spent_outpoints.pop(prevout_hash)
@modifier
def set_spent_outpoint(self, prevout_hash, prevout_n, tx_hash):
if prevout_hash not in self.spent_outpoints:
self.spent_outpoints[prevout_hash] = {}
self.spent_outpoints[prevout_hash][str(prevout_n)] = tx_hash
@modifier
def add_transaction(self, tx_hash, tx):
self.transactions[tx_hash] = str(tx)
@modifier
def remove_transaction(self, tx_hash):
self.transactions.pop(tx_hash, None)
@@ -543,9 +580,11 @@ class JsonDB(PrintError):
def get_addr_history(self, addr):
return self.history.get(addr, [])
@modifier
def set_addr_history(self, addr, hist):
self.history[addr] = hist
@modifier
def remove_addr_history(self, addr):
self.history.pop(addr, None)
@@ -562,18 +601,22 @@ class JsonDB(PrintError):
txpos=txpos,
header_hash=header_hash)
@modifier
def add_verified_tx(self, txid, info):
self.verified_tx[txid] = (info.height, info.timestamp, info.txpos, info.header_hash)
@modifier
def remove_verified_tx(self, txid):
self.verified_tx.pop(txid, None)
@modifier
def update_tx_fees(self, d):
return self.tx_fees.update(d)
def get_tx_fee(self, txid):
return self.tx_fees.get(txid)
@modifier
def remove_tx_fee(self, txid):
self.tx_fees.pop(txid, None)