1
0

Merge pull request #10233 from spesmilo/jsondb_pointers

Jsondb pointers
This commit is contained in:
ThomasV
2025-11-07 10:21:27 +01:00
committed by GitHub
8 changed files with 191 additions and 99 deletions

View File

@@ -108,91 +108,111 @@ def key_path(path: Sequence[Union[str, int]], key: Optional[str]) -> str:
items = [to_str(x) for x in path]
if key is not None:
items.append(to_str(key))
return '/' + '/'.join(items)
return '/'.join(items)
class BaseStoredObject:
_db: 'JsonDB' = None
_key = None
_parent = None
_lock = None
def set_db(self, db):
self._db = db
self._lock = self._db.lock if self._db else threading.RLock()
def set_parent(self, key, parent):
self._key = key
self._parent = parent
@property
def lock(self):
return self._lock
@property
def path(self) -> Sequence[str]:
# return None iff we are pruned from root
x = self
s = [x._key]
while x._parent is not None:
x = x._parent
s = [x._key] + s
if x._key != '':
return None
assert self._db is not None
return s
def db_add(self, key, value):
if self.path:
self._db.add(self.path, key, value)
def db_replace(self, key, value):
if self.path:
self._db.replace(self.path, key, value)
def db_remove(self, key):
if self.path:
self._db.remove(self.path, key)
class StoredObject:
db: 'JsonDB' = None
path = None
class StoredObject(BaseStoredObject):
"""for attr.s objects """
def __setattr__(self, key, value):
if self.db and key not in ['path', 'db'] and not key.startswith('_'):
if self.path and not key.startswith('_'):
if value != getattr(self, key):
self.db.add_patch({'op': 'replace', 'path': key_path(self.path, key), 'value': value})
self.db_replace(key, value)
object.__setattr__(self, key, value)
def set_db(self, db, path):
self.db = db
self.path = path
def to_json(self):
d = dict(vars(self))
d.pop('db', None)
d.pop('path', None)
# don't expose/store private stuff
d = {k: v for k, v in d.items()
if not k.startswith('_')}
return d
_RaiseKeyError = object() # singleton for no-default behavior
class StoredDict(dict):
def __init__(self, data, db: 'JsonDB', path):
self.db = db
self.lock = self.db.lock if self.db else threading.RLock()
self.path = path
class StoredDict(dict, BaseStoredObject):
def __init__(self, data: dict, db: 'JsonDB'):
self.set_db(db)
# recursively convert dicts to StoredDict
for k, v in list(data.items()):
self.__setitem__(k, v, patch=False)
self.__setitem__(k, v)
@locked
def __setitem__(self, key, v, patch=True):
def __setitem__(self, key, v):
is_new = key not in self
# early return to prevent unnecessary disk writes
if not is_new and patch:
if self.db and json.dumps(v, cls=self.db.encoder) == json.dumps(self[key], cls=self.db.encoder):
return
# recursively set db and path
if isinstance(v, StoredDict):
#assert v.db is None
v.db = self.db
v.path = self.path + [key]
for k, vv in v.items():
v.__setitem__(k, vv, patch=False)
# recursively convert dict to StoredDict.
# _convert_dict is called breadth-first
elif isinstance(v, dict):
if self.db:
v = self.db._convert_dict(self.path, key, v)
if not self.db or self.db._should_convert_to_stored_dict(key):
v = StoredDict(v, self.db, self.path + [key])
# convert_value is called depth-first
if isinstance(v, dict) or isinstance(v, str) or isinstance(v, int):
if self.db:
v = self.db._convert_value(self.path, key, v)
# set parent of StoredObject
if isinstance(v, StoredObject):
v.set_db(self.db, self.path + [key])
# convert lists
if isinstance(v, list):
v = StoredList(v, self.db, self.path + [key])
if not is_new and self._db and json.dumps(v, cls=self._db.encoder) == json.dumps(self[key], cls=self._db.encoder):
return
# convert dict to StoredDict.
if type(v) == dict and (self._db is None or self._db._should_convert_to_stored_dict(key)):
v = StoredDict(v, self._db)
# convert list to StoredList
elif type(v) == list:
v = StoredList(v, self._db)
# reject sets. they do not work well with jsonpatch
if isinstance(v, set):
elif isinstance(v, set):
raise Exception(f"Do not store sets inside jsondb. path={self.path!r}")
# set db for StoredObject, because it is not set in the constructor
if isinstance(v, StoredObject):
v.set_db(self._db)
# set parent
if isinstance(v, BaseStoredObject):
v.set_parent(key, self)
# set item
dict.__setitem__(self, key, v)
if self.db and patch:
op = 'add' if is_new else 'replace'
self.db.add_patch({'op': op, 'path': key_path(self.path, key), 'value': v})
self.db_add(key, v) if is_new else self.db_replace(key, v)
@locked
def __delitem__(self, key):
dict.__delitem__(self, key)
if self.db:
self.db.add_patch({'op': 'remove', 'path': key_path(self.path, key)})
self.db_remove(key)
@locked
def pop(self, key, v=_RaiseKeyError):
@@ -202,8 +222,9 @@ class StoredDict(dict):
else:
return v
r = dict.pop(self, key)
if self.db:
self.db.add_patch({'op': 'remove', 'path': key_path(self.path, key)})
self.db_remove(key)
if isinstance(r, StoredDict):
r._parent = None
return r
def setdefault(self, key, default = None, /):
@@ -212,33 +233,28 @@ class StoredDict(dict):
return self[key]
class StoredList(list):
class StoredList(list, BaseStoredObject):
def __init__(self, data, db: 'JsonDB', path):
def __init__(self, data, db: 'JsonDB'):
list.__init__(self, data)
self.db = db
self.lock = self.db.lock if self.db else threading.RLock()
self.path = path
self.set_db(db)
@locked
def append(self, item):
n = len(self)
list.append(self, item)
if self.db:
self.db.add_patch({'op': 'add', 'path': key_path(self.path, '%d'%n), 'value':item})
self.db_add('%d'%n, item)
@locked
def remove(self, item):
n = self.index(item)
list.remove(self, item)
if self.db:
self.db.add_patch({'op': 'remove', 'path': key_path(self.path, '%d'%n)})
self.db_remove('%d'%n)
@locked
def clear(self):
list.clear(self)
if self.db:
self.db.add_patch({'op': 'replace', 'path': key_path(self.path, None), 'value':[]})
self.db_replace(None, [])
@@ -263,8 +279,11 @@ class JsonDB(Logger):
if upgrader:
data, was_upgraded = upgrader(data)
self._modified |= was_upgraded
# convert to StoredDict
self.data = StoredDict(data, self, [])
# convert json to python objects
data = self._convert_dict([], data)
# convert dict to StoredDict
self.data = StoredDict(data, self)
self.data.set_parent('', None)
# write file in case there was a db upgrade
if self.storage and self.storage.file_exists():
self.write_and_force_consolidation()
@@ -338,6 +357,15 @@ class JsonDB(Logger):
self.pending_changes.append(json.dumps(patch, cls=self.encoder))
self.set_modified(True)
def add(self, path, key, value):
self.add_patch({'op': 'add', 'path': key_path(path, key), 'value': value})
def replace(self, path, key, value):
self.add_patch({'op': 'replace', 'path': key_path(path, key), 'value': value})
def remove(self, path, key):
self.add_patch({'op': 'remove', 'path': key_path(path, key)})
@locked
def get(self, key, default=None):
v = self.data.get(key)
@@ -391,7 +419,22 @@ class JsonDB(Logger):
def _should_convert_to_stored_dict(self, key) -> bool:
return True
def _convert_dict(self, path, key, v):
def _convert_dict_key(self, path):
key = path[-1]
parent_key = path[-2] if len(path) > 1 else None
gp_key = path[-3] if len(path) > 2 else None
if parent_key and parent_key in registered_dict_keys:
convert_key = registered_dict_keys[parent_key]
elif gp_key and gp_key in registered_parent_keys:
convert_key = registered_parent_keys.get(gp_key)
else:
convert_key = None
if convert_key:
key = convert_key(key)
return key
def _convert_dict_value(self, path, v):
key = path[-1]
if key in registered_dicts:
constructor, _type = registered_dicts[key]
if _type == dict:
@@ -400,25 +443,26 @@ class JsonDB(Logger):
v = dict((k, constructor(*x)) for k, x in v.items())
else:
v = dict((k, constructor(x)) for k, x in v.items())
if key in registered_dict_keys:
convert_key = registered_dict_keys[key]
elif path and path[-1] in registered_parent_keys:
convert_key = registered_parent_keys.get(path[-1])
else:
convert_key = None
if convert_key:
v = dict((convert_key(k), x) for k, x in v.items())
return v
def _convert_value(self, path, key, v):
if key in registered_names:
elif key in registered_names:
constructor, _type = registered_names[key]
if _type == dict:
v = constructor(**v)
else:
v = constructor(v)
if isinstance(v, dict):
v = self._convert_dict(path, v)
return v
def _convert_dict(self, path, data: dict):
# recursively convert dict to StoredDict
d = {}
for k, v in list(data.items()):
child_path = path + [k]
k = self._convert_dict_key(child_path)
v = self._convert_dict_value(child_path, v)
d[k] = v
return d
@locked
def write(self):
if self.storage.should_do_full_write_next():

View File

@@ -773,7 +773,7 @@ class Channel(AbstractChannel):
Logger.__init__(self) # should be after short_channel_id is set
self.lnworker = lnworker
self.storage = state
self.db_lock = self.storage.db.lock if self.storage.db else threading.RLock()
self.db_lock = self.storage.lock
self.config = {}
self.config[LOCAL] = state["local_config"]
self.config[REMOTE] = state["remote_config"]

View File

@@ -1263,8 +1263,7 @@ class Peer(Logger, EventListener):
"revocation_store": {},
"channel_type": channel_type,
}
# set db to None, because we do not want to write updates until channel is saved
return StoredDict(chan_dict, None, [])
return StoredDict(chan_dict, self.lnworker.db)
@non_blocking_msg_handler
async def on_open_channel(self, payload):

View File

@@ -1279,7 +1279,7 @@ def upgrade_wallet_db(data: dict, do_upgrade: bool) -> Tuple[dict, bool]:
first_electrum_version_used=ELECTRUM_VERSION,
)
assert data.get("db_metadata", None) is None
data["db_metadata"] = v
data["db_metadata"] = v.to_json()
was_upgraded = True
dbu = WalletDBUpgrader(data)

View File

@@ -1,6 +1,7 @@
import contextlib
import copy
import traceback
import json
import jsonpatch
from jsonpatch import JsonPatchException
@@ -8,6 +9,7 @@ from jsonpointer import JsonPointerException
from . import ElectrumTestCase
from electrum.json_db import JsonDB
class TestJsonpatch(ElectrumTestCase):
@@ -84,3 +86,50 @@ class TestJsonpatch(ElectrumTestCase):
with self._customAssertRaises(JsonPointerException) as ctx:
data2 = jpatch.apply(data1)
fail_if_leaking_secret(ctx)
class TestJsonDB(ElectrumTestCase):
async def test_jsonpatch_replace_after_remove(self):
data = { 'a':{} }
# op "add"
patches = [{"op": "add", "path": "/a/b", "value": "42"}]
jpatch = jsonpatch.JsonPatch(patches)
data = jpatch.apply(data)
# remove
patches = [{"op": "remove", "path": "/a/b"}]
jpatch = jsonpatch.JsonPatch(patches)
data = jpatch.apply(data)
# replace
patches = [{"op": "replace", "path": "/a/b", "value": "43"}]
jpatch = jsonpatch.JsonPatch(patches)
with self.assertRaises(JsonPatchException):
data = jpatch.apply(data)
async def test_jsondb_replace_after_remove(self):
data = { 'a': {'b': {'c': 0}}}
db = JsonDB(repr(data))
a = db.get_dict('a')
# remove
b = a.pop('b')
self.assertEqual(len(db.pending_changes), 1)
# replace item. this must not been written to db
b['c'] = 42
self.assertEqual(len(db.pending_changes), 1)
patches = json.loads('[' + ','.join(db.pending_changes) + ']')
jpatch = jsonpatch.JsonPatch(patches)
data = jpatch.apply(data)
async def test_jsondb_replace_after_remove_nested(self):
data = { 'a': {'b':{'c':0}}}
db = JsonDB(repr(data))
# remove
a = db.data.pop('a')
self.assertEqual(len(db.pending_changes), 1)
b = a['b']
# replace item. this must not be written to db
b['c'] = 42
self.assertEqual(len(db.pending_changes), 1)
patches = json.loads('[' + ','.join(db.pending_changes) + ']')
jpatch = jsonpatch.JsonPatch(patches)
data = jpatch.apply(data)

View File

@@ -118,7 +118,7 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
'revocation_store': {},
'channel_type': channel_type,
}
return StoredDict(state, None, [])
return StoredDict(state, None)
def bip32(sequence):

View File

@@ -14,8 +14,8 @@ class H(NamedTuple):
class TestHTLCManager(ElectrumTestCase):
def test_adding_htlcs_race(self):
A = HTLCManager(StoredDict({}, None, []))
B = HTLCManager(StoredDict({}, None, []))
A = HTLCManager(StoredDict({}, None))
B = HTLCManager(StoredDict({}, None))
A.channel_open_finished()
B.channel_open_finished()
ah0, bh0 = H('A', 0), H('B', 0)
@@ -61,8 +61,8 @@ class TestHTLCManager(ElectrumTestCase):
def test_single_htlc_full_lifecycle(self):
def htlc_lifecycle(htlc_success: bool):
A = HTLCManager(StoredDict({}, None, []))
B = HTLCManager(StoredDict({}, None, []))
A = HTLCManager(StoredDict({}, None))
B = HTLCManager(StoredDict({}, None))
A.channel_open_finished()
B.channel_open_finished()
B.recv_htlc(A.send_htlc(H('A', 0)))
@@ -134,8 +134,8 @@ class TestHTLCManager(ElectrumTestCase):
def test_remove_htlc_while_owing_commitment(self):
def htlc_lifecycle(htlc_success: bool):
A = HTLCManager(StoredDict({}, None, []))
B = HTLCManager(StoredDict({}, None, []))
A = HTLCManager(StoredDict({}, None))
B = HTLCManager(StoredDict({}, None))
A.channel_open_finished()
B.channel_open_finished()
ah0 = H('A', 0)
@@ -171,8 +171,8 @@ class TestHTLCManager(ElectrumTestCase):
htlc_lifecycle(htlc_success=False)
def test_adding_htlc_between_send_ctx_and_recv_rev(self):
A = HTLCManager(StoredDict({}, None, []))
B = HTLCManager(StoredDict({}, None, []))
A = HTLCManager(StoredDict({}, None))
B = HTLCManager(StoredDict({}, None))
A.channel_open_finished()
B.channel_open_finished()
A.send_ctx()
@@ -217,8 +217,8 @@ class TestHTLCManager(ElectrumTestCase):
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
def test_unacked_local_updates(self):
A = HTLCManager(StoredDict({}, None, []))
B = HTLCManager(StoredDict({}, None, []))
A = HTLCManager(StoredDict({}, None))
B = HTLCManager(StoredDict({}, None))
A.channel_open_finished()
B.channel_open_finished()
self.assertEqual({}, A.get_unacked_local_updates())

View File

@@ -474,7 +474,7 @@ class TestLNUtil(ElectrumTestCase):
]
for test in tests:
receiver = RevocationStore(StoredDict({}, None, []))
receiver = RevocationStore(StoredDict({}, None))
for insert in test["inserts"]:
secret = bytes.fromhex(insert["secret"])
@@ -497,7 +497,7 @@ class TestLNUtil(ElectrumTestCase):
def test_shachain_produce_consume(self):
seed = bitcoin.sha256(b"shachaintest")
consumer = RevocationStore(StoredDict({}, None, []))
consumer = RevocationStore(StoredDict({}, None))
for i in range(10000):
secret = get_per_commitment_secret_from_seed(seed, RevocationStore.START_INDEX - i)
try:
@@ -507,7 +507,7 @@ class TestLNUtil(ElectrumTestCase):
if i % 1000 == 0:
c1 = consumer
s1 = json.dumps(c1.storage, cls=MyEncoder)
c2 = RevocationStore(StoredDict(json.loads(s1), None, []))
c2 = RevocationStore(StoredDict(json.loads(s1), None))
s2 = json.dumps(c2.storage, cls=MyEncoder)
self.assertEqual(s1, s2)