Merge pull request #10233 from spesmilo/jsondb_pointers
Jsondb pointers
@@ -108,91 +108,111 @@ def key_path(path: Sequence[Union[str, int]], key: Optional[str]) -> str:
    items = [to_str(x) for x in path]
    if key is not None:
        items.append(to_str(key))
-    return '/' + '/'.join(items)
+    return '/'.join(items)

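Illustrative sketch (not from this patch): with the new parent-pointer scheme the path of an attached object starts with the root key '', so a plain '/'.join already yields a JSON-pointer-style string and the old leading-slash concatenation becomes redundant. A minimal standalone rendering of key_path, assuming to_str simply stringifies its argument:

    from typing import Optional, Sequence, Union

    def to_str(x: Union[str, int]) -> str:
        return str(x)

    def key_path(path: Sequence[Union[str, int]], key: Optional[str]) -> str:
        items = [to_str(x) for x in path]
        if key is not None:
            items.append(to_str(key))
        return '/'.join(items)

    # The root StoredDict has key '', so joined paths come out with a leading slash:
    assert key_path(['', 'channels'], 'chan1') == '/channels/chan1'
    assert key_path([''], None) == ''
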
+class BaseStoredObject:
+
+    _db: 'JsonDB' = None
+    _key = None
+    _parent = None
+    _lock = None
+
+    def set_db(self, db):
+        self._db = db
+        self._lock = self._db.lock if self._db else threading.RLock()
+
+    def set_parent(self, key, parent):
+        self._key = key
+        self._parent = parent
+
+    @property
+    def lock(self):
+        return self._lock
+
+    @property
+    def path(self) -> Sequence[str]:
+        # return None iff we are pruned from root
+        x = self
+        s = [x._key]
+        while x._parent is not None:
+            x = x._parent
+            s = [x._key] + s
+        if x._key != '':
+            return None
+        assert self._db is not None
+        return s
+
+    def db_add(self, key, value):
+        if self.path:
+            self._db.add(self.path, key, value)
+
+    def db_replace(self, key, value):
+        if self.path:
+            self._db.replace(self.path, key, value)
+
+    def db_remove(self, key):
+        if self.path:
+            self._db.remove(self.path, key)

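Illustrative sketch (not from this patch): the _parent/_key pointers let an object recompute its location on demand instead of carrying a stored path. Once an object is detached from the chain that ends at the root (whose key is ''), path returns None and db_add/db_replace/db_remove silently do nothing. A toy model of that traversal, using a hypothetical _Node class:

    class _Node:
        # Hypothetical stand-in mirroring the parent-pointer walk above.
        def __init__(self, key, parent):
            self._key = key
            self._parent = parent

        @property
        def path(self):
            x = self
            s = [x._key]
            while x._parent is not None:
                x = x._parent
                s = [x._key] + s
            # only chains ending at the root key '' are still attached
            return s if x._key == '' else None

    root = _Node('', None)
    child = _Node('channels', root)
    assert child.path == ['', 'channels']
    child._parent = None            # pruned, e.g. after dict.pop() on the parent
    assert child.path is None       # subsequent writes would be skipped
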
-class StoredObject:
-
-    db: 'JsonDB' = None
-    path = None
+class StoredObject(BaseStoredObject):
    """for attr.s objects """

    def __setattr__(self, key, value):
-        if self.db and key not in ['path', 'db'] and not key.startswith('_'):
+        if self.path and not key.startswith('_'):
            if value != getattr(self, key):
-                self.db.add_patch({'op': 'replace', 'path': key_path(self.path, key), 'value': value})
+                self.db_replace(key, value)
        object.__setattr__(self, key, value)

-    def set_db(self, db, path):
-        self.db = db
-        self.path = path

    def to_json(self):
        d = dict(vars(self))
        d.pop('db', None)
        d.pop('path', None)
        # don't expose/store private stuff
        d = {k: v for k, v in d.items()
             if not k.startswith('_')}
        return d

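Illustrative sketch (not from this patch): for attrs-based StoredObject subclasses, attribute assignment now only records a patch when the object is reachable from the root and the value actually changes. A hypothetical mock of that guard (the real class relies on attrs having already initialised the attribute, hence the getattr default here):

    class _FakeStoredObject:
        # Hypothetical mock; records 'replace' patches instead of talking to a JsonDB.
        def __init__(self):
            object.__setattr__(self, '_patches', [])
            object.__setattr__(self, 'path', ['', 'local_config'])   # pretend we are attached

        def db_replace(self, key, value):
            self._patches.append(('replace', '/'.join(self.path + [key]), value))

        def __setattr__(self, key, value):
            if self.path and not key.startswith('_'):
                if value != getattr(self, key, object()):
                    self.db_replace(key, value)
            object.__setattr__(self, key, value)

    o = _FakeStoredObject()
    o.to_self_delay = 144
    o.to_self_delay = 144      # unchanged value: no second patch
    assert len(o._patches) == 1
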
_RaiseKeyError = object() # singleton for no-default behavior


-class StoredDict(dict):
-
-    def __init__(self, data, db: 'JsonDB', path):
-        self.db = db
-        self.lock = self.db.lock if self.db else threading.RLock()
-        self.path = path
+class StoredDict(dict, BaseStoredObject):
+
+    def __init__(self, data: dict, db: 'JsonDB'):
+        self.set_db(db)
        # recursively convert dicts to StoredDict
        for k, v in list(data.items()):
-            self.__setitem__(k, v, patch=False)
+            self.__setitem__(k, v)

    @locked
-    def __setitem__(self, key, v, patch=True):
+    def __setitem__(self, key, v):
        is_new = key not in self
        # early return to prevent unnecessary disk writes
-        if not is_new and patch:
-            if self.db and json.dumps(v, cls=self.db.encoder) == json.dumps(self[key], cls=self.db.encoder):
-                return
-        # recursively set db and path
-        if isinstance(v, StoredDict):
-            #assert v.db is None
-            v.db = self.db
-            v.path = self.path + [key]
-            for k, vv in v.items():
-                v.__setitem__(k, vv, patch=False)
-        # recursively convert dict to StoredDict.
-        # _convert_dict is called breadth-first
-        elif isinstance(v, dict):
-            if self.db:
-                v = self.db._convert_dict(self.path, key, v)
-            if not self.db or self.db._should_convert_to_stored_dict(key):
-                v = StoredDict(v, self.db, self.path + [key])
-        # convert_value is called depth-first
-        if isinstance(v, dict) or isinstance(v, str) or isinstance(v, int):
-            if self.db:
-                v = self.db._convert_value(self.path, key, v)
-        # set parent of StoredObject
-        if isinstance(v, StoredObject):
-            v.set_db(self.db, self.path + [key])
-        # convert lists
-        if isinstance(v, list):
-            v = StoredList(v, self.db, self.path + [key])
+        if not is_new and self._db and json.dumps(v, cls=self._db.encoder) == json.dumps(self[key], cls=self._db.encoder):
+            return
+        # convert dict to StoredDict.
+        if type(v) == dict and (self._db is None or self._db._should_convert_to_stored_dict(key)):
+            v = StoredDict(v, self._db)
+        # convert list to StoredList
+        elif type(v) == list:
+            v = StoredList(v, self._db)
        # reject sets. they do not work well with jsonpatch
-        if isinstance(v, set):
+        elif isinstance(v, set):
            raise Exception(f"Do not store sets inside jsondb. path={self.path!r}")
+        # set db for StoredObject, because it is not set in the constructor
+        if isinstance(v, StoredObject):
+            v.set_db(self._db)
+        # set parent
+        if isinstance(v, BaseStoredObject):
+            v.set_parent(key, self)
        # set item
        dict.__setitem__(self, key, v)
-        if self.db and patch:
-            op = 'add' if is_new else 'replace'
-            self.db.add_patch({'op': op, 'path': key_path(self.path, key), 'value': v})
+        self.db_add(key, v) if is_new else self.db_replace(key, v)

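Illustrative sketch (not from this patch): the tail of __setitem__ maps every effective mutation onto one JSON Patch operation ('add' for new keys, 'replace' for existing ones), while the json.dumps comparison drops writes that would not change the stored value. A simplified, self-contained model of that decision:

    import json

    pending_patches = []    # hypothetical stand-in for JsonDB.pending_changes

    def set_item(d: dict, pointer_prefix: str, key: str, value) -> None:
        is_new = key not in d
        # early return: identical serialized value means no patch and no disk write
        if not is_new and json.dumps(value) == json.dumps(d[key]):
            return
        d[key] = value
        op = 'add' if is_new else 'replace'
        pending_patches.append({'op': op, 'path': f'{pointer_prefix}/{key}', 'value': value})

    d = {}
    set_item(d, '/channels', 'chan1', {'state': 'OPENING'})   # -> add
    set_item(d, '/channels', 'chan1', {'state': 'OPENING'})   # unchanged -> nothing
    set_item(d, '/channels', 'chan1', {'state': 'OPEN'})      # -> replace
    assert [p['op'] for p in pending_patches] == ['add', 'replace']
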
    @locked
    def __delitem__(self, key):
        dict.__delitem__(self, key)
-        if self.db:
-            self.db.add_patch({'op': 'remove', 'path': key_path(self.path, key)})
+        self.db_remove(key)

    @locked
    def pop(self, key, v=_RaiseKeyError):
@@ -202,8 +222,9 @@ class StoredDict(dict):
        else:
            return v
        r = dict.pop(self, key)
-        if self.db:
-            self.db.add_patch({'op': 'remove', 'path': key_path(self.path, key)})
+        self.db_remove(key)
+        if isinstance(r, StoredDict):
+            r._parent = None
        return r

    def setdefault(self, key, default = None, /):
@@ -212,33 +233,28 @@ class StoredDict(dict):
        return self[key]


-class StoredList(list):
+class StoredList(list, BaseStoredObject):

-    def __init__(self, data, db: 'JsonDB', path):
+    def __init__(self, data, db: 'JsonDB'):
        list.__init__(self, data)
-        self.db = db
-        self.lock = self.db.lock if self.db else threading.RLock()
-        self.path = path
+        self.set_db(db)

    @locked
    def append(self, item):
        n = len(self)
        list.append(self, item)
-        if self.db:
-            self.db.add_patch({'op': 'add', 'path': key_path(self.path, '%d'%n), 'value':item})
+        self.db_add('%d'%n, item)

    @locked
    def remove(self, item):
        n = self.index(item)
        list.remove(self, item)
-        if self.db:
-            self.db.add_patch({'op': 'remove', 'path': key_path(self.path, '%d'%n)})
+        self.db_remove('%d'%n)

    @locked
    def clear(self):
        list.clear(self)
-        if self.db:
-            self.db.add_patch({'op': 'replace', 'path': key_path(self.path, None), 'value':[]})
+        self.db_replace(None, [])


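Illustrative sketch (not from this patch): list elements are addressed by their integer index inside the JSON pointer, so append/remove/clear translate into 'add'/'remove'/'replace' operations at '/<list path>/<n>'. A small demonstration using the jsonpatch API that the tests below rely on:

    import jsonpatch   # same third-party library the tests import

    doc = {'invoices': ['inv0']}
    patches = [
        {'op': 'add', 'path': '/invoices/1', 'value': 'inv1'},   # like StoredList.append
        {'op': 'remove', 'path': '/invoices/0'},                 # like StoredList.remove
        {'op': 'replace', 'path': '/invoices', 'value': []},     # like StoredList.clear
    ]
    doc = jsonpatch.JsonPatch(patches).apply(doc)
    assert doc == {'invoices': []}
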
@@ -263,8 +279,11 @@ class JsonDB(Logger):
        if upgrader:
            data, was_upgraded = upgrader(data)
            self._modified |= was_upgraded
-        # convert to StoredDict
-        self.data = StoredDict(data, self, [])
+        # convert json to python objects
+        data = self._convert_dict([], data)
+        # convert dict to StoredDict
+        self.data = StoredDict(data, self)
+        self.data.set_parent('', None)
        # write file in case there was a db upgrade
        if self.storage and self.storage.file_exists():
            self.write_and_force_consolidation()
@@ -338,6 +357,15 @@ class JsonDB(Logger):
        self.pending_changes.append(json.dumps(patch, cls=self.encoder))
        self.set_modified(True)

+    def add(self, path, key, value):
+        self.add_patch({'op': 'add', 'path': key_path(path, key), 'value': value})
+
+    def replace(self, path, key, value):
+        self.add_patch({'op': 'replace', 'path': key_path(path, key), 'value': value})
+
+    def remove(self, path, key):
+        self.add_patch({'op': 'remove', 'path': key_path(path, key)})
+
    @locked
    def get(self, key, default=None):
        v = self.data.get(key)
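Illustrative sketch (not from this patch): db_add/db_replace/db_remove on BaseStoredObject funnel into these three helpers, which simply queue one serialized JSON Patch operation each in pending_changes. A hypothetical, simplified stand-in for that queue:

    import json

    class _MiniPatchQueue:
        # Hypothetical simplification of the JsonDB patch queue (no encoder, no lock).
        def __init__(self):
            self.pending_changes = []

        def add_patch(self, patch: dict) -> None:
            self.pending_changes.append(json.dumps(patch))

        def add(self, pointer: str, value) -> None:
            self.add_patch({'op': 'add', 'path': pointer, 'value': value})

        def remove(self, pointer: str) -> None:
            self.add_patch({'op': 'remove', 'path': pointer})

    q = _MiniPatchQueue()
    q.add('/channels/chan1', {'state': 'OPENING'})
    q.remove('/channels/chan1')
    assert len(q.pending_changes) == 2
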
@@ -391,7 +419,22 @@ class JsonDB(Logger):
    def _should_convert_to_stored_dict(self, key) -> bool:
        return True

-    def _convert_dict(self, path, key, v):
+    def _convert_dict_key(self, path):
+        key = path[-1]
+        parent_key = path[-2] if len(path) > 1 else None
+        gp_key = path[-3] if len(path) > 2 else None
+        if parent_key and parent_key in registered_dict_keys:
+            convert_key = registered_dict_keys[parent_key]
+        elif gp_key and gp_key in registered_parent_keys:
+            convert_key = registered_parent_keys.get(gp_key)
+        else:
+            convert_key = None
+        if convert_key:
+            key = convert_key(key)
+        return key
+
+    def _convert_dict_value(self, path, v):
+        key = path[-1]
        if key in registered_dicts:
            constructor, _type = registered_dicts[key]
            if _type == dict:
@@ -400,25 +443,26 @@ class JsonDB(Logger):
                v = dict((k, constructor(*x)) for k, x in v.items())
            else:
                v = dict((k, constructor(x)) for k, x in v.items())
-        if key in registered_dict_keys:
-            convert_key = registered_dict_keys[key]
-        elif path and path[-1] in registered_parent_keys:
-            convert_key = registered_parent_keys.get(path[-1])
-        else:
-            convert_key = None
-        if convert_key:
-            v = dict((convert_key(k), x) for k, x in v.items())
-        return v
-
-    def _convert_value(self, path, key, v):
-        if key in registered_names:
+        elif key in registered_names:
            constructor, _type = registered_names[key]
            if _type == dict:
                v = constructor(**v)
            else:
                v = constructor(v)
+        if isinstance(v, dict):
+            v = self._convert_dict(path, v)
        return v

+    def _convert_dict(self, path, data: dict):
+        # recursively convert dict to StoredDict
+        d = {}
+        for k, v in list(data.items()):
+            child_path = path + [k]
+            k = self._convert_dict_key(child_path)
+            v = self._convert_dict_value(child_path, v)
+            d[k] = v
+        return d
+
    @locked
    def write(self):
        if self.storage.should_do_full_write_next():

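Illustrative sketch (not from this patch): the load path now converts the raw JSON recursively before wrapping it in StoredDict: _convert_dict_key rewrites dict keys via the converters registered for the parent (or grandparent) key, and _convert_dict_value rebuilds values via the registered constructors, recursing into child dicts. A toy version of that walk with a made-up registry:

    # Hypothetical toy registry: children of 'buckets' get integer keys.
    registered_dict_keys = {'buckets': int}

    def convert_key(path):
        key = path[-1]
        parent_key = path[-2] if len(path) > 1 else None
        if parent_key and parent_key in registered_dict_keys:
            key = registered_dict_keys[parent_key](key)
        return key

    def convert_dict(path, data):
        out = {}
        for k, v in data.items():
            child_path = path + [k]
            k = convert_key(child_path)
            if isinstance(v, dict):
                v = convert_dict(child_path, v)
            out[k] = v
        return out

    raw = {'buckets': {'0': 'a', '1': 'b'}}
    assert convert_dict([], raw) == {'buckets': {0: 'a', 1: 'b'}}
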
@@ -773,7 +773,7 @@ class Channel(AbstractChannel):
        Logger.__init__(self) # should be after short_channel_id is set
        self.lnworker = lnworker
        self.storage = state
-        self.db_lock = self.storage.db.lock if self.storage.db else threading.RLock()
+        self.db_lock = self.storage.lock
        self.config = {}
        self.config[LOCAL] = state["local_config"]
        self.config[REMOTE] = state["remote_config"]

@@ -1263,8 +1263,7 @@ class Peer(Logger, EventListener):
            "revocation_store": {},
            "channel_type": channel_type,
        }
-        # set db to None, because we do not want to write updates until channel is saved
-        return StoredDict(chan_dict, None, [])
+        return StoredDict(chan_dict, self.lnworker.db)

    @non_blocking_msg_handler
    async def on_open_channel(self, payload):

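Illustrative sketch (not from this patch): the freshly created channel dict now gets the real db, but since it has no parent yet its path is None, so nothing is written until it is actually inserted into the wallet db. A hedged end-to-end example mirroring the test-style API used below (import path and behaviour assumed from those tests):

    from electrum.json_db import JsonDB, StoredDict   # assumed import path, as in the tests

    db = JsonDB(repr({'channels': {}}))
    chan = StoredDict({'state': 'OPENING'}, db)   # has a db, but no parent yet
    chan['state'] = 'OPEN'                        # detached from root: no patch recorded
    assert len(db.pending_changes) == 0
    db.get_dict('channels')['chan1'] = chan       # attaching it produces the first patch
    assert len(db.pending_changes) == 1
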
@@ -1279,7 +1279,7 @@ def upgrade_wallet_db(data: dict, do_upgrade: bool) -> Tuple[dict, bool]:
            first_electrum_version_used=ELECTRUM_VERSION,
        )
        assert data.get("db_metadata", None) is None
-        data["db_metadata"] = v
+        data["db_metadata"] = v.to_json()
        was_upgraded = True

    dbu = WalletDBUpgrader(data)

@@ -1,6 +1,7 @@
import contextlib
import copy
import traceback
+import json

import jsonpatch
from jsonpatch import JsonPatchException
@@ -8,6 +9,7 @@ from jsonpointer import JsonPointerException

from . import ElectrumTestCase

+from electrum.json_db import JsonDB

class TestJsonpatch(ElectrumTestCase):

@@ -84,3 +86,50 @@ class TestJsonpatch(ElectrumTestCase):
        with self._customAssertRaises(JsonPointerException) as ctx:
            data2 = jpatch.apply(data1)
        fail_if_leaking_secret(ctx)


+class TestJsonDB(ElectrumTestCase):
+
+    async def test_jsonpatch_replace_after_remove(self):
+        data = { 'a':{} }
+        # op "add"
+        patches = [{"op": "add", "path": "/a/b", "value": "42"}]
+        jpatch = jsonpatch.JsonPatch(patches)
+        data = jpatch.apply(data)
+        # remove
+        patches = [{"op": "remove", "path": "/a/b"}]
+        jpatch = jsonpatch.JsonPatch(patches)
+        data = jpatch.apply(data)
+        # replace
+        patches = [{"op": "replace", "path": "/a/b", "value": "43"}]
+        jpatch = jsonpatch.JsonPatch(patches)
+        with self.assertRaises(JsonPatchException):
+            data = jpatch.apply(data)
+
+    async def test_jsondb_replace_after_remove(self):
+        data = { 'a': {'b': {'c': 0}}}
+        db = JsonDB(repr(data))
+        a = db.get_dict('a')
+        # remove
+        b = a.pop('b')
+        self.assertEqual(len(db.pending_changes), 1)
+        # replace item. this must not be written to db
+        b['c'] = 42
+        self.assertEqual(len(db.pending_changes), 1)
+        patches = json.loads('[' + ','.join(db.pending_changes) + ']')
+        jpatch = jsonpatch.JsonPatch(patches)
+        data = jpatch.apply(data)
+
+    async def test_jsondb_replace_after_remove_nested(self):
+        data = { 'a': {'b':{'c':0}}}
+        db = JsonDB(repr(data))
+        # remove
+        a = db.data.pop('a')
+        self.assertEqual(len(db.pending_changes), 1)
+        b = a['b']
+        # replace item. this must not be written to db
+        b['c'] = 42
+        self.assertEqual(len(db.pending_changes), 1)
+        patches = json.loads('[' + ','.join(db.pending_changes) + ']')
+        jpatch = jsonpatch.JsonPatch(patches)
+        data = jpatch.apply(data)

@@ -118,7 +118,7 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
        'revocation_store': {},
        'channel_type': channel_type,
    }
-    return StoredDict(state, None, [])
+    return StoredDict(state, None)


def bip32(sequence):

@@ -14,8 +14,8 @@ class H(NamedTuple):

class TestHTLCManager(ElectrumTestCase):
    def test_adding_htlcs_race(self):
-        A = HTLCManager(StoredDict({}, None, []))
-        B = HTLCManager(StoredDict({}, None, []))
+        A = HTLCManager(StoredDict({}, None))
+        B = HTLCManager(StoredDict({}, None))
        A.channel_open_finished()
        B.channel_open_finished()
        ah0, bh0 = H('A', 0), H('B', 0)
@@ -61,8 +61,8 @@ class TestHTLCManager(ElectrumTestCase):

    def test_single_htlc_full_lifecycle(self):
        def htlc_lifecycle(htlc_success: bool):
-            A = HTLCManager(StoredDict({}, None, []))
-            B = HTLCManager(StoredDict({}, None, []))
+            A = HTLCManager(StoredDict({}, None))
+            B = HTLCManager(StoredDict({}, None))
            A.channel_open_finished()
            B.channel_open_finished()
            B.recv_htlc(A.send_htlc(H('A', 0)))
@@ -134,8 +134,8 @@ class TestHTLCManager(ElectrumTestCase):

    def test_remove_htlc_while_owing_commitment(self):
        def htlc_lifecycle(htlc_success: bool):
-            A = HTLCManager(StoredDict({}, None, []))
-            B = HTLCManager(StoredDict({}, None, []))
+            A = HTLCManager(StoredDict({}, None))
+            B = HTLCManager(StoredDict({}, None))
            A.channel_open_finished()
            B.channel_open_finished()
            ah0 = H('A', 0)
@@ -171,8 +171,8 @@ class TestHTLCManager(ElectrumTestCase):
            htlc_lifecycle(htlc_success=False)

    def test_adding_htlc_between_send_ctx_and_recv_rev(self):
-        A = HTLCManager(StoredDict({}, None, []))
-        B = HTLCManager(StoredDict({}, None, []))
+        A = HTLCManager(StoredDict({}, None))
+        B = HTLCManager(StoredDict({}, None))
        A.channel_open_finished()
        B.channel_open_finished()
        A.send_ctx()
@@ -217,8 +217,8 @@ class TestHTLCManager(ElectrumTestCase):
        self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))

    def test_unacked_local_updates(self):
-        A = HTLCManager(StoredDict({}, None, []))
-        B = HTLCManager(StoredDict({}, None, []))
+        A = HTLCManager(StoredDict({}, None))
+        B = HTLCManager(StoredDict({}, None))
        A.channel_open_finished()
        B.channel_open_finished()
        self.assertEqual({}, A.get_unacked_local_updates())

@@ -474,7 +474,7 @@ class TestLNUtil(ElectrumTestCase):
        ]

        for test in tests:
-            receiver = RevocationStore(StoredDict({}, None, []))
+            receiver = RevocationStore(StoredDict({}, None))
            for insert in test["inserts"]:
                secret = bytes.fromhex(insert["secret"])

@@ -497,7 +497,7 @@ class TestLNUtil(ElectrumTestCase):

    def test_shachain_produce_consume(self):
        seed = bitcoin.sha256(b"shachaintest")
-        consumer = RevocationStore(StoredDict({}, None, []))
+        consumer = RevocationStore(StoredDict({}, None))
        for i in range(10000):
            secret = get_per_commitment_secret_from_seed(seed, RevocationStore.START_INDEX - i)
            try:

@@ -507,7 +507,7 @@ class TestLNUtil(ElectrumTestCase):
                if i % 1000 == 0:
                    c1 = consumer
                    s1 = json.dumps(c1.storage, cls=MyEncoder)
-                    c2 = RevocationStore(StoredDict(json.loads(s1), None, []))
+                    c2 = RevocationStore(StoredDict(json.loads(s1), None))
                    s2 = json.dumps(c2.storage, cls=MyEncoder)
                    self.assertEqual(s1, s2)