1
0

StoredDict: use pointers instead of path

Instead of storing its own path, each StoredDict element stores
its own key and a pointer to its parent. If a dict is removed
from the db, its parent pointer is set to None. This makes
self.path return None for all branches that have been pruned.

This passes tests/tests_json_db.py and fixes issue #10000
This commit is contained in:
ThomasV
2025-07-09 10:31:32 +02:00
parent 53c1817956
commit 077bcf515d
6 changed files with 109 additions and 77 deletions

View File

@@ -108,84 +108,111 @@ def key_path(path: Sequence[Union[str, int]], key: Optional[str]) -> str:
items = [to_str(x) for x in path]
if key is not None:
items.append(to_str(key))
return '/' + '/'.join(items)
return '/'.join(items)
class BaseStoredObject:
_db: 'JsonDB' = None
_key = None
_parent = None
_lock = None
def set_db(self, db):
self._db = db
self._lock = self._db.lock if self._db else threading.RLock()
def set_parent(self, key, parent):
self._key = key
self._parent = parent
@property
def lock(self):
return self._lock
@property
def path(self) -> Sequence[str]:
# return None iff we are pruned from root
x = self
s = [x._key]
while x._parent is not None:
x = x._parent
s = [x._key] + s
if x._key != '':
return None
assert self._db is not None
return s
def db_add(self, key, value):
if self.path:
self._db.add(self.path, key, value)
def db_replace(self, key, value):
if self.path:
self._db.replace(self.path, key, value)
def db_remove(self, key):
if self.path:
self._db.remove(self.path, key)
class StoredObject:
db: 'JsonDB' = None
path = None
class StoredObject(BaseStoredObject):
"""for attr.s objects """
def __setattr__(self, key, value):
if self.db and key not in ['path', 'db'] and not key.startswith('_'):
if self.path and not key.startswith('_'):
if value != getattr(self, key):
self.db.add_patch({'op': 'replace', 'path': key_path(self.path, key), 'value': value})
self.db_replace(key, value)
object.__setattr__(self, key, value)
def set_db(self, db, path):
self.db = db
self.path = path
def to_json(self):
d = dict(vars(self))
d.pop('db', None)
d.pop('path', None)
# don't expose/store private stuff
d = {k: v for k, v in d.items()
if not k.startswith('_')}
return d
_RaiseKeyError = object() # singleton for no-default behavior
class StoredDict(dict):
def __init__(self, data, db: 'JsonDB', path):
self.db = db
self.lock = self.db.lock if self.db else threading.RLock()
self.path = path
class StoredDict(dict, BaseStoredObject):
def __init__(self, data: dict, db: 'JsonDB'):
self.set_db(db)
# recursively convert dicts to StoredDict
for k, v in list(data.items()):
self.__setitem__(k, v, patch=False)
self.__setitem__(k, v)
@locked
def __setitem__(self, key, v, patch=True):
def __setitem__(self, key, v):
is_new = key not in self
# early return to prevent unnecessary disk writes
if not is_new and patch:
if self.db and json.dumps(v, cls=self.db.encoder) == json.dumps(self[key], cls=self.db.encoder):
return
# recursively set db and path
if isinstance(v, StoredDict):
#assert v.db is None
v.db = self.db
v.path = self.path + [key]
for k, vv in v.items():
v.__setitem__(k, vv, patch=False)
# recursively convert dict to StoredDict.
elif isinstance(v, dict):
if not self.db or self.db._should_convert_to_stored_dict(key):
v = StoredDict(v, self.db, self.path + [key])
# set parent of StoredObject
if isinstance(v, StoredObject):
v.set_db(self.db, self.path + [key])
# convert lists
if isinstance(v, list):
v = StoredList(v, self.db, self.path + [key])
if not is_new and self._db and json.dumps(v, cls=self._db.encoder) == json.dumps(self[key], cls=self._db.encoder):
return
# convert dict to StoredDict.
if type(v) == dict and (self._db is None or self._db._should_convert_to_stored_dict(key)):
v = StoredDict(v, self._db)
# convert list to StoredList
elif type(v) == list:
v = StoredList(v, self._db)
# reject sets. they do not work well with jsonpatch
if isinstance(v, set):
elif isinstance(v, set):
raise Exception(f"Do not store sets inside jsondb. path={self.path!r}")
# set db for StoredObject, because it is not set in the constructor
if isinstance(v, StoredObject):
v.set_db(self._db)
# set parent
if isinstance(v, BaseStoredObject):
v.set_parent(key, self)
# set item
dict.__setitem__(self, key, v)
if self.db and patch:
op = 'add' if is_new else 'replace'
self.db.add_patch({'op': op, 'path': key_path(self.path, key), 'value': v})
self.db_add(key, v) if is_new else self.db_replace(key, v)
@locked
def __delitem__(self, key):
dict.__delitem__(self, key)
if self.db:
self.db.add_patch({'op': 'remove', 'path': key_path(self.path, key)})
self.db_remove(key)
@locked
def pop(self, key, v=_RaiseKeyError):
@@ -195,8 +222,9 @@ class StoredDict(dict):
else:
return v
r = dict.pop(self, key)
if self.db:
self.db.add_patch({'op': 'remove', 'path': key_path(self.path, key)})
self.db_remove(key)
if isinstance(r, StoredDict):
r._parent = None
return r
def setdefault(self, key, default = None, /):
@@ -205,33 +233,28 @@ class StoredDict(dict):
return self[key]
class StoredList(list):
class StoredList(list, BaseStoredObject):
def __init__(self, data, db: 'JsonDB', path):
def __init__(self, data, db: 'JsonDB'):
list.__init__(self, data)
self.db = db
self.lock = self.db.lock if self.db else threading.RLock()
self.path = path
self.set_db(db)
@locked
def append(self, item):
n = len(self)
list.append(self, item)
if self.db:
self.db.add_patch({'op': 'add', 'path': key_path(self.path, '%d'%n), 'value':item})
self.db_add('%d'%n, item)
@locked
def remove(self, item):
n = self.index(item)
list.remove(self, item)
if self.db:
self.db.add_patch({'op': 'remove', 'path': key_path(self.path, '%d'%n)})
self.db_remove('%d'%n)
@locked
def clear(self):
list.clear(self)
if self.db:
self.db.add_patch({'op': 'replace', 'path': key_path(self.path, None), 'value':[]})
self.db_replace(None, [])
@@ -259,7 +282,8 @@ class JsonDB(Logger):
# convert json to python objects
data = self._convert_dict([], data)
# convert dict to StoredDict
self.data = StoredDict(data, self, [])
self.data = StoredDict(data, self)
self.data.set_parent('', None)
# write file in case there was a db upgrade
if self.storage and self.storage.file_exists():
self.write_and_force_consolidation()
@@ -333,6 +357,15 @@ class JsonDB(Logger):
self.pending_changes.append(json.dumps(patch, cls=self.encoder))
self.set_modified(True)
def add(self, path, key, value):
self.add_patch({'op': 'add', 'path': key_path(path, key), 'value': value})
def replace(self, path, key, value):
self.add_patch({'op': 'replace', 'path': key_path(path, key), 'value': value})
def remove(self, path, key):
self.add_patch({'op': 'remove', 'path': key_path(path, key)})
@locked
def get(self, key, default=None):
v = self.data.get(key)

View File

@@ -773,7 +773,7 @@ class Channel(AbstractChannel):
Logger.__init__(self) # should be after short_channel_id is set
self.lnworker = lnworker
self.storage = state
self.db_lock = self.storage.db.lock if self.storage.db else threading.RLock()
self.db_lock = self.storage.lock
self.config = {}
self.config[LOCAL] = state["local_config"]
self.config[REMOTE] = state["remote_config"]

View File

@@ -1263,8 +1263,7 @@ class Peer(Logger, EventListener):
"revocation_store": {},
"channel_type": channel_type,
}
# set db to None, because we do not want to write updates until channel is saved
return StoredDict(chan_dict, None, [])
return StoredDict(chan_dict, self.lnworker.db)
@non_blocking_msg_handler
async def on_open_channel(self, payload):

View File

@@ -118,7 +118,7 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
'revocation_store': {},
'channel_type': channel_type,
}
return StoredDict(state, None, [])
return StoredDict(state, None)
def bip32(sequence):

View File

@@ -14,8 +14,8 @@ class H(NamedTuple):
class TestHTLCManager(ElectrumTestCase):
def test_adding_htlcs_race(self):
A = HTLCManager(StoredDict({}, None, []))
B = HTLCManager(StoredDict({}, None, []))
A = HTLCManager(StoredDict({}, None))
B = HTLCManager(StoredDict({}, None))
A.channel_open_finished()
B.channel_open_finished()
ah0, bh0 = H('A', 0), H('B', 0)
@@ -61,8 +61,8 @@ class TestHTLCManager(ElectrumTestCase):
def test_single_htlc_full_lifecycle(self):
def htlc_lifecycle(htlc_success: bool):
A = HTLCManager(StoredDict({}, None, []))
B = HTLCManager(StoredDict({}, None, []))
A = HTLCManager(StoredDict({}, None))
B = HTLCManager(StoredDict({}, None))
A.channel_open_finished()
B.channel_open_finished()
B.recv_htlc(A.send_htlc(H('A', 0)))
@@ -134,8 +134,8 @@ class TestHTLCManager(ElectrumTestCase):
def test_remove_htlc_while_owing_commitment(self):
def htlc_lifecycle(htlc_success: bool):
A = HTLCManager(StoredDict({}, None, []))
B = HTLCManager(StoredDict({}, None, []))
A = HTLCManager(StoredDict({}, None))
B = HTLCManager(StoredDict({}, None))
A.channel_open_finished()
B.channel_open_finished()
ah0 = H('A', 0)
@@ -171,8 +171,8 @@ class TestHTLCManager(ElectrumTestCase):
htlc_lifecycle(htlc_success=False)
def test_adding_htlc_between_send_ctx_and_recv_rev(self):
A = HTLCManager(StoredDict({}, None, []))
B = HTLCManager(StoredDict({}, None, []))
A = HTLCManager(StoredDict({}, None))
B = HTLCManager(StoredDict({}, None))
A.channel_open_finished()
B.channel_open_finished()
A.send_ctx()
@@ -217,8 +217,8 @@ class TestHTLCManager(ElectrumTestCase):
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
def test_unacked_local_updates(self):
A = HTLCManager(StoredDict({}, None, []))
B = HTLCManager(StoredDict({}, None, []))
A = HTLCManager(StoredDict({}, None))
B = HTLCManager(StoredDict({}, None))
A.channel_open_finished()
B.channel_open_finished()
self.assertEqual({}, A.get_unacked_local_updates())

View File

@@ -474,7 +474,7 @@ class TestLNUtil(ElectrumTestCase):
]
for test in tests:
receiver = RevocationStore(StoredDict({}, None, []))
receiver = RevocationStore(StoredDict({}, None))
for insert in test["inserts"]:
secret = bytes.fromhex(insert["secret"])
@@ -497,7 +497,7 @@ class TestLNUtil(ElectrumTestCase):
def test_shachain_produce_consume(self):
seed = bitcoin.sha256(b"shachaintest")
consumer = RevocationStore(StoredDict({}, None, []))
consumer = RevocationStore(StoredDict({}, None))
for i in range(10000):
secret = get_per_commitment_secret_from_seed(seed, RevocationStore.START_INDEX - i)
try:
@@ -507,7 +507,7 @@ class TestLNUtil(ElectrumTestCase):
if i % 1000 == 0:
c1 = consumer
s1 = json.dumps(c1.storage, cls=MyEncoder)
c2 = RevocationStore(StoredDict(json.loads(s1), None, []))
c2 = RevocationStore(StoredDict(json.loads(s1), None))
s2 = json.dumps(c2.storage, cls=MyEncoder)
self.assertEqual(s1, s2)