diff --git a/electrum/json_db.py b/electrum/json_db.py index a79f691ae..b49398338 100644 --- a/electrum/json_db.py +++ b/electrum/json_db.py @@ -108,91 +108,111 @@ def key_path(path: Sequence[Union[str, int]], key: Optional[str]) -> str: items = [to_str(x) for x in path] if key is not None: items.append(to_str(key)) - return '/' + '/'.join(items) + return '/'.join(items) + +class BaseStoredObject: + + _db: 'JsonDB' = None + _key = None + _parent = None + _lock = None + + def set_db(self, db): + self._db = db + self._lock = self._db.lock if self._db else threading.RLock() + + def set_parent(self, key, parent): + self._key = key + self._parent = parent + + @property + def lock(self): + return self._lock + + @property + def path(self) -> Sequence[str]: + # return None iff we are pruned from root + x = self + s = [x._key] + while x._parent is not None: + x = x._parent + s = [x._key] + s + if x._key != '': + return None + assert self._db is not None + return s + + def db_add(self, key, value): + if self.path: + self._db.add(self.path, key, value) + + def db_replace(self, key, value): + if self.path: + self._db.replace(self.path, key, value) + + def db_remove(self, key): + if self.path: + self._db.remove(self.path, key) -class StoredObject: - - db: 'JsonDB' = None - path = None +class StoredObject(BaseStoredObject): + """for attr.s objects """ def __setattr__(self, key, value): - if self.db and key not in ['path', 'db'] and not key.startswith('_'): + if self.path and not key.startswith('_'): if value != getattr(self, key): - self.db.add_patch({'op': 'replace', 'path': key_path(self.path, key), 'value': value}) + self.db_replace(key, value) object.__setattr__(self, key, value) - def set_db(self, db, path): - self.db = db - self.path = path - def to_json(self): d = dict(vars(self)) - d.pop('db', None) - d.pop('path', None) # don't expose/store private stuff d = {k: v for k, v in d.items() if not k.startswith('_')} return d + _RaiseKeyError = object() # singleton for no-default behavior -class StoredDict(dict): - def __init__(self, data, db: 'JsonDB', path): - self.db = db - self.lock = self.db.lock if self.db else threading.RLock() - self.path = path +class StoredDict(dict, BaseStoredObject): + + def __init__(self, data: dict, db: 'JsonDB'): + self.set_db(db) # recursively convert dicts to StoredDict for k, v in list(data.items()): - self.__setitem__(k, v, patch=False) + self.__setitem__(k, v) @locked - def __setitem__(self, key, v, patch=True): + def __setitem__(self, key, v): is_new = key not in self # early return to prevent unnecessary disk writes - if not is_new and patch: - if self.db and json.dumps(v, cls=self.db.encoder) == json.dumps(self[key], cls=self.db.encoder): - return - # recursively set db and path - if isinstance(v, StoredDict): - #assert v.db is None - v.db = self.db - v.path = self.path + [key] - for k, vv in v.items(): - v.__setitem__(k, vv, patch=False) - # recursively convert dict to StoredDict. - # _convert_dict is called breadth-first - elif isinstance(v, dict): - if self.db: - v = self.db._convert_dict(self.path, key, v) - if not self.db or self.db._should_convert_to_stored_dict(key): - v = StoredDict(v, self.db, self.path + [key]) - # convert_value is called depth-first - if isinstance(v, dict) or isinstance(v, str) or isinstance(v, int): - if self.db: - v = self.db._convert_value(self.path, key, v) - # set parent of StoredObject - if isinstance(v, StoredObject): - v.set_db(self.db, self.path + [key]) - # convert lists - if isinstance(v, list): - v = StoredList(v, self.db, self.path + [key]) + if not is_new and self._db and json.dumps(v, cls=self._db.encoder) == json.dumps(self[key], cls=self._db.encoder): + return + # convert dict to StoredDict. + if type(v) == dict and (self._db is None or self._db._should_convert_to_stored_dict(key)): + v = StoredDict(v, self._db) + # convert list to StoredList + elif type(v) == list: + v = StoredList(v, self._db) # reject sets. they do not work well with jsonpatch - if isinstance(v, set): + elif isinstance(v, set): raise Exception(f"Do not store sets inside jsondb. path={self.path!r}") + # set db for StoredObject, because it is not set in the constructor + if isinstance(v, StoredObject): + v.set_db(self._db) + # set parent + if isinstance(v, BaseStoredObject): + v.set_parent(key, self) # set item dict.__setitem__(self, key, v) - if self.db and patch: - op = 'add' if is_new else 'replace' - self.db.add_patch({'op': op, 'path': key_path(self.path, key), 'value': v}) + self.db_add(key, v) if is_new else self.db_replace(key, v) @locked def __delitem__(self, key): dict.__delitem__(self, key) - if self.db: - self.db.add_patch({'op': 'remove', 'path': key_path(self.path, key)}) + self.db_remove(key) @locked def pop(self, key, v=_RaiseKeyError): @@ -202,8 +222,9 @@ class StoredDict(dict): else: return v r = dict.pop(self, key) - if self.db: - self.db.add_patch({'op': 'remove', 'path': key_path(self.path, key)}) + self.db_remove(key) + if isinstance(r, StoredDict): + r._parent = None return r def setdefault(self, key, default = None, /): @@ -212,33 +233,28 @@ class StoredDict(dict): return self[key] -class StoredList(list): +class StoredList(list, BaseStoredObject): - def __init__(self, data, db: 'JsonDB', path): + def __init__(self, data, db: 'JsonDB'): list.__init__(self, data) - self.db = db - self.lock = self.db.lock if self.db else threading.RLock() - self.path = path + self.set_db(db) @locked def append(self, item): n = len(self) list.append(self, item) - if self.db: - self.db.add_patch({'op': 'add', 'path': key_path(self.path, '%d'%n), 'value':item}) + self.db_add('%d'%n, item) @locked def remove(self, item): n = self.index(item) list.remove(self, item) - if self.db: - self.db.add_patch({'op': 'remove', 'path': key_path(self.path, '%d'%n)}) + self.db_remove('%d'%n) @locked def clear(self): list.clear(self) - if self.db: - self.db.add_patch({'op': 'replace', 'path': key_path(self.path, None), 'value':[]}) + self.db_replace(None, []) @@ -263,8 +279,11 @@ class JsonDB(Logger): if upgrader: data, was_upgraded = upgrader(data) self._modified |= was_upgraded - # convert to StoredDict - self.data = StoredDict(data, self, []) + # convert json to python objects + data = self._convert_dict([], data) + # convert dict to StoredDict + self.data = StoredDict(data, self) + self.data.set_parent('', None) # write file in case there was a db upgrade if self.storage and self.storage.file_exists(): self.write_and_force_consolidation() @@ -338,6 +357,15 @@ class JsonDB(Logger): self.pending_changes.append(json.dumps(patch, cls=self.encoder)) self.set_modified(True) + def add(self, path, key, value): + self.add_patch({'op': 'add', 'path': key_path(path, key), 'value': value}) + + def replace(self, path, key, value): + self.add_patch({'op': 'replace', 'path': key_path(path, key), 'value': value}) + + def remove(self, path, key): + self.add_patch({'op': 'remove', 'path': key_path(path, key)}) + @locked def get(self, key, default=None): v = self.data.get(key) @@ -391,7 +419,22 @@ class JsonDB(Logger): def _should_convert_to_stored_dict(self, key) -> bool: return True - def _convert_dict(self, path, key, v): + def _convert_dict_key(self, path): + key = path[-1] + parent_key = path[-2] if len(path) > 1 else None + gp_key = path[-3] if len(path) > 2 else None + if parent_key and parent_key in registered_dict_keys: + convert_key = registered_dict_keys[parent_key] + elif gp_key and gp_key in registered_parent_keys: + convert_key = registered_parent_keys.get(gp_key) + else: + convert_key = None + if convert_key: + key = convert_key(key) + return key + + def _convert_dict_value(self, path, v): + key = path[-1] if key in registered_dicts: constructor, _type = registered_dicts[key] if _type == dict: @@ -400,25 +443,26 @@ class JsonDB(Logger): v = dict((k, constructor(*x)) for k, x in v.items()) else: v = dict((k, constructor(x)) for k, x in v.items()) - if key in registered_dict_keys: - convert_key = registered_dict_keys[key] - elif path and path[-1] in registered_parent_keys: - convert_key = registered_parent_keys.get(path[-1]) - else: - convert_key = None - if convert_key: - v = dict((convert_key(k), x) for k, x in v.items()) - return v - - def _convert_value(self, path, key, v): - if key in registered_names: + elif key in registered_names: constructor, _type = registered_names[key] if _type == dict: v = constructor(**v) else: v = constructor(v) + if isinstance(v, dict): + v = self._convert_dict(path, v) return v + def _convert_dict(self, path, data: dict): + # recursively convert dict to StoredDict + d = {} + for k, v in list(data.items()): + child_path = path + [k] + k = self._convert_dict_key(child_path) + v = self._convert_dict_value(child_path, v) + d[k] = v + return d + @locked def write(self): if self.storage.should_do_full_write_next(): diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index 2d57a4a8b..dd5fedaca 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -773,7 +773,7 @@ class Channel(AbstractChannel): Logger.__init__(self) # should be after short_channel_id is set self.lnworker = lnworker self.storage = state - self.db_lock = self.storage.db.lock if self.storage.db else threading.RLock() + self.db_lock = self.storage.lock self.config = {} self.config[LOCAL] = state["local_config"] self.config[REMOTE] = state["remote_config"] diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 7424cb47a..39c6cf943 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -1263,8 +1263,7 @@ class Peer(Logger, EventListener): "revocation_store": {}, "channel_type": channel_type, } - # set db to None, because we do not want to write updates until channel is saved - return StoredDict(chan_dict, None, []) + return StoredDict(chan_dict, self.lnworker.db) @non_blocking_msg_handler async def on_open_channel(self, payload): diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py index 95d038e9a..586afb42e 100644 --- a/electrum/wallet_db.py +++ b/electrum/wallet_db.py @@ -1279,7 +1279,7 @@ def upgrade_wallet_db(data: dict, do_upgrade: bool) -> Tuple[dict, bool]: first_electrum_version_used=ELECTRUM_VERSION, ) assert data.get("db_metadata", None) is None - data["db_metadata"] = v + data["db_metadata"] = v.to_json() was_upgraded = True dbu = WalletDBUpgrader(data) diff --git a/tests/test_jsondb.py b/tests/test_jsondb.py index 16129e114..eeaf41079 100644 --- a/tests/test_jsondb.py +++ b/tests/test_jsondb.py @@ -1,6 +1,7 @@ import contextlib import copy import traceback +import json import jsonpatch from jsonpatch import JsonPatchException @@ -8,6 +9,7 @@ from jsonpointer import JsonPointerException from . import ElectrumTestCase +from electrum.json_db import JsonDB class TestJsonpatch(ElectrumTestCase): @@ -84,3 +86,50 @@ class TestJsonpatch(ElectrumTestCase): with self._customAssertRaises(JsonPointerException) as ctx: data2 = jpatch.apply(data1) fail_if_leaking_secret(ctx) + + +class TestJsonDB(ElectrumTestCase): + + async def test_jsonpatch_replace_after_remove(self): + data = { 'a':{} } + # op "add" + patches = [{"op": "add", "path": "/a/b", "value": "42"}] + jpatch = jsonpatch.JsonPatch(patches) + data = jpatch.apply(data) + # remove + patches = [{"op": "remove", "path": "/a/b"}] + jpatch = jsonpatch.JsonPatch(patches) + data = jpatch.apply(data) + # replace + patches = [{"op": "replace", "path": "/a/b", "value": "43"}] + jpatch = jsonpatch.JsonPatch(patches) + with self.assertRaises(JsonPatchException): + data = jpatch.apply(data) + + async def test_jsondb_replace_after_remove(self): + data = { 'a': {'b': {'c': 0}}} + db = JsonDB(repr(data)) + a = db.get_dict('a') + # remove + b = a.pop('b') + self.assertEqual(len(db.pending_changes), 1) + # replace item. this must not been written to db + b['c'] = 42 + self.assertEqual(len(db.pending_changes), 1) + patches = json.loads('[' + ','.join(db.pending_changes) + ']') + jpatch = jsonpatch.JsonPatch(patches) + data = jpatch.apply(data) + + async def test_jsondb_replace_after_remove_nested(self): + data = { 'a': {'b':{'c':0}}} + db = JsonDB(repr(data)) + # remove + a = db.data.pop('a') + self.assertEqual(len(db.pending_changes), 1) + b = a['b'] + # replace item. this must not be written to db + b['c'] = 42 + self.assertEqual(len(db.pending_changes), 1) + patches = json.loads('[' + ','.join(db.pending_changes) + ']') + jpatch = jsonpatch.JsonPatch(patches) + data = jpatch.apply(data) diff --git a/tests/test_lnchannel.py b/tests/test_lnchannel.py index 59d7f0024..2ba4c4092 100644 --- a/tests/test_lnchannel.py +++ b/tests/test_lnchannel.py @@ -118,7 +118,7 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator, 'revocation_store': {}, 'channel_type': channel_type, } - return StoredDict(state, None, []) + return StoredDict(state, None) def bip32(sequence): diff --git a/tests/test_lnhtlc.py b/tests/test_lnhtlc.py index 833052478..c8ac53db3 100644 --- a/tests/test_lnhtlc.py +++ b/tests/test_lnhtlc.py @@ -14,8 +14,8 @@ class H(NamedTuple): class TestHTLCManager(ElectrumTestCase): def test_adding_htlcs_race(self): - A = HTLCManager(StoredDict({}, None, [])) - B = HTLCManager(StoredDict({}, None, [])) + A = HTLCManager(StoredDict({}, None)) + B = HTLCManager(StoredDict({}, None)) A.channel_open_finished() B.channel_open_finished() ah0, bh0 = H('A', 0), H('B', 0) @@ -61,8 +61,8 @@ class TestHTLCManager(ElectrumTestCase): def test_single_htlc_full_lifecycle(self): def htlc_lifecycle(htlc_success: bool): - A = HTLCManager(StoredDict({}, None, [])) - B = HTLCManager(StoredDict({}, None, [])) + A = HTLCManager(StoredDict({}, None)) + B = HTLCManager(StoredDict({}, None)) A.channel_open_finished() B.channel_open_finished() B.recv_htlc(A.send_htlc(H('A', 0))) @@ -134,8 +134,8 @@ class TestHTLCManager(ElectrumTestCase): def test_remove_htlc_while_owing_commitment(self): def htlc_lifecycle(htlc_success: bool): - A = HTLCManager(StoredDict({}, None, [])) - B = HTLCManager(StoredDict({}, None, [])) + A = HTLCManager(StoredDict({}, None)) + B = HTLCManager(StoredDict({}, None)) A.channel_open_finished() B.channel_open_finished() ah0 = H('A', 0) @@ -171,8 +171,8 @@ class TestHTLCManager(ElectrumTestCase): htlc_lifecycle(htlc_success=False) def test_adding_htlc_between_send_ctx_and_recv_rev(self): - A = HTLCManager(StoredDict({}, None, [])) - B = HTLCManager(StoredDict({}, None, [])) + A = HTLCManager(StoredDict({}, None)) + B = HTLCManager(StoredDict({}, None)) A.channel_open_finished() B.channel_open_finished() A.send_ctx() @@ -217,8 +217,8 @@ class TestHTLCManager(ElectrumTestCase): self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE)) def test_unacked_local_updates(self): - A = HTLCManager(StoredDict({}, None, [])) - B = HTLCManager(StoredDict({}, None, [])) + A = HTLCManager(StoredDict({}, None)) + B = HTLCManager(StoredDict({}, None)) A.channel_open_finished() B.channel_open_finished() self.assertEqual({}, A.get_unacked_local_updates()) diff --git a/tests/test_lnutil.py b/tests/test_lnutil.py index 273e9d041..944d8a889 100644 --- a/tests/test_lnutil.py +++ b/tests/test_lnutil.py @@ -474,7 +474,7 @@ class TestLNUtil(ElectrumTestCase): ] for test in tests: - receiver = RevocationStore(StoredDict({}, None, [])) + receiver = RevocationStore(StoredDict({}, None)) for insert in test["inserts"]: secret = bytes.fromhex(insert["secret"]) @@ -497,7 +497,7 @@ class TestLNUtil(ElectrumTestCase): def test_shachain_produce_consume(self): seed = bitcoin.sha256(b"shachaintest") - consumer = RevocationStore(StoredDict({}, None, [])) + consumer = RevocationStore(StoredDict({}, None)) for i in range(10000): secret = get_per_commitment_secret_from_seed(seed, RevocationStore.START_INDEX - i) try: @@ -507,7 +507,7 @@ class TestLNUtil(ElectrumTestCase): if i % 1000 == 0: c1 = consumer s1 = json.dumps(c1.storage, cls=MyEncoder) - c2 = RevocationStore(StoredDict(json.loads(s1), None, [])) + c2 = RevocationStore(StoredDict(json.loads(s1), None)) s2 = json.dumps(c2.storage, cls=MyEncoder) self.assertEqual(s1, s2)