StoredDict: use pointers instead of path
Instead of storing its own path, each StoredDict element stores its own key and a pointer to its parent. If a dict is removed from the db, its parent pointer is set to None. This makes self.path return None for all branches that have been pruned. This passes tests/tests_json_db.py and fixes issue #10000
This commit is contained in:
@@ -108,84 +108,111 @@ def key_path(path: Sequence[Union[str, int]], key: Optional[str]) -> str:
|
||||
items = [to_str(x) for x in path]
|
||||
if key is not None:
|
||||
items.append(to_str(key))
|
||||
return '/' + '/'.join(items)
|
||||
return '/'.join(items)
|
||||
|
||||
class BaseStoredObject:
|
||||
|
||||
_db: 'JsonDB' = None
|
||||
_key = None
|
||||
_parent = None
|
||||
_lock = None
|
||||
|
||||
def set_db(self, db):
|
||||
self._db = db
|
||||
self._lock = self._db.lock if self._db else threading.RLock()
|
||||
|
||||
def set_parent(self, key, parent):
|
||||
self._key = key
|
||||
self._parent = parent
|
||||
|
||||
@property
|
||||
def lock(self):
|
||||
return self._lock
|
||||
|
||||
@property
|
||||
def path(self) -> Sequence[str]:
|
||||
# return None iff we are pruned from root
|
||||
x = self
|
||||
s = [x._key]
|
||||
while x._parent is not None:
|
||||
x = x._parent
|
||||
s = [x._key] + s
|
||||
if x._key != '':
|
||||
return None
|
||||
assert self._db is not None
|
||||
return s
|
||||
|
||||
def db_add(self, key, value):
|
||||
if self.path:
|
||||
self._db.add(self.path, key, value)
|
||||
|
||||
def db_replace(self, key, value):
|
||||
if self.path:
|
||||
self._db.replace(self.path, key, value)
|
||||
|
||||
def db_remove(self, key):
|
||||
if self.path:
|
||||
self._db.remove(self.path, key)
|
||||
|
||||
|
||||
class StoredObject:
|
||||
|
||||
db: 'JsonDB' = None
|
||||
path = None
|
||||
class StoredObject(BaseStoredObject):
|
||||
"""for attr.s objects """
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if self.db and key not in ['path', 'db'] and not key.startswith('_'):
|
||||
if self.path and not key.startswith('_'):
|
||||
if value != getattr(self, key):
|
||||
self.db.add_patch({'op': 'replace', 'path': key_path(self.path, key), 'value': value})
|
||||
self.db_replace(key, value)
|
||||
object.__setattr__(self, key, value)
|
||||
|
||||
def set_db(self, db, path):
|
||||
self.db = db
|
||||
self.path = path
|
||||
|
||||
def to_json(self):
|
||||
d = dict(vars(self))
|
||||
d.pop('db', None)
|
||||
d.pop('path', None)
|
||||
# don't expose/store private stuff
|
||||
d = {k: v for k, v in d.items()
|
||||
if not k.startswith('_')}
|
||||
return d
|
||||
|
||||
|
||||
|
||||
_RaiseKeyError = object() # singleton for no-default behavior
|
||||
|
||||
class StoredDict(dict):
|
||||
|
||||
def __init__(self, data, db: 'JsonDB', path):
|
||||
self.db = db
|
||||
self.lock = self.db.lock if self.db else threading.RLock()
|
||||
self.path = path
|
||||
class StoredDict(dict, BaseStoredObject):
|
||||
|
||||
def __init__(self, data: dict, db: 'JsonDB'):
|
||||
self.set_db(db)
|
||||
# recursively convert dicts to StoredDict
|
||||
for k, v in list(data.items()):
|
||||
self.__setitem__(k, v, patch=False)
|
||||
self.__setitem__(k, v)
|
||||
|
||||
@locked
|
||||
def __setitem__(self, key, v, patch=True):
|
||||
def __setitem__(self, key, v):
|
||||
is_new = key not in self
|
||||
# early return to prevent unnecessary disk writes
|
||||
if not is_new and patch:
|
||||
if self.db and json.dumps(v, cls=self.db.encoder) == json.dumps(self[key], cls=self.db.encoder):
|
||||
return
|
||||
# recursively set db and path
|
||||
if isinstance(v, StoredDict):
|
||||
#assert v.db is None
|
||||
v.db = self.db
|
||||
v.path = self.path + [key]
|
||||
for k, vv in v.items():
|
||||
v.__setitem__(k, vv, patch=False)
|
||||
# recursively convert dict to StoredDict.
|
||||
elif isinstance(v, dict):
|
||||
if not self.db or self.db._should_convert_to_stored_dict(key):
|
||||
v = StoredDict(v, self.db, self.path + [key])
|
||||
# set parent of StoredObject
|
||||
if isinstance(v, StoredObject):
|
||||
v.set_db(self.db, self.path + [key])
|
||||
# convert lists
|
||||
if isinstance(v, list):
|
||||
v = StoredList(v, self.db, self.path + [key])
|
||||
if not is_new and self._db and json.dumps(v, cls=self._db.encoder) == json.dumps(self[key], cls=self._db.encoder):
|
||||
return
|
||||
# convert dict to StoredDict.
|
||||
if type(v) == dict and (self._db is None or self._db._should_convert_to_stored_dict(key)):
|
||||
v = StoredDict(v, self._db)
|
||||
# convert list to StoredList
|
||||
elif type(v) == list:
|
||||
v = StoredList(v, self._db)
|
||||
# reject sets. they do not work well with jsonpatch
|
||||
if isinstance(v, set):
|
||||
elif isinstance(v, set):
|
||||
raise Exception(f"Do not store sets inside jsondb. path={self.path!r}")
|
||||
# set db for StoredObject, because it is not set in the constructor
|
||||
if isinstance(v, StoredObject):
|
||||
v.set_db(self._db)
|
||||
# set parent
|
||||
if isinstance(v, BaseStoredObject):
|
||||
v.set_parent(key, self)
|
||||
# set item
|
||||
dict.__setitem__(self, key, v)
|
||||
if self.db and patch:
|
||||
op = 'add' if is_new else 'replace'
|
||||
self.db.add_patch({'op': op, 'path': key_path(self.path, key), 'value': v})
|
||||
self.db_add(key, v) if is_new else self.db_replace(key, v)
|
||||
|
||||
@locked
|
||||
def __delitem__(self, key):
|
||||
dict.__delitem__(self, key)
|
||||
if self.db:
|
||||
self.db.add_patch({'op': 'remove', 'path': key_path(self.path, key)})
|
||||
self.db_remove(key)
|
||||
|
||||
@locked
|
||||
def pop(self, key, v=_RaiseKeyError):
|
||||
@@ -195,8 +222,9 @@ class StoredDict(dict):
|
||||
else:
|
||||
return v
|
||||
r = dict.pop(self, key)
|
||||
if self.db:
|
||||
self.db.add_patch({'op': 'remove', 'path': key_path(self.path, key)})
|
||||
self.db_remove(key)
|
||||
if isinstance(r, StoredDict):
|
||||
r._parent = None
|
||||
return r
|
||||
|
||||
def setdefault(self, key, default = None, /):
|
||||
@@ -205,33 +233,28 @@ class StoredDict(dict):
|
||||
return self[key]
|
||||
|
||||
|
||||
class StoredList(list):
|
||||
class StoredList(list, BaseStoredObject):
|
||||
|
||||
def __init__(self, data, db: 'JsonDB', path):
|
||||
def __init__(self, data, db: 'JsonDB'):
|
||||
list.__init__(self, data)
|
||||
self.db = db
|
||||
self.lock = self.db.lock if self.db else threading.RLock()
|
||||
self.path = path
|
||||
self.set_db(db)
|
||||
|
||||
@locked
|
||||
def append(self, item):
|
||||
n = len(self)
|
||||
list.append(self, item)
|
||||
if self.db:
|
||||
self.db.add_patch({'op': 'add', 'path': key_path(self.path, '%d'%n), 'value':item})
|
||||
self.db_add('%d'%n, item)
|
||||
|
||||
@locked
|
||||
def remove(self, item):
|
||||
n = self.index(item)
|
||||
list.remove(self, item)
|
||||
if self.db:
|
||||
self.db.add_patch({'op': 'remove', 'path': key_path(self.path, '%d'%n)})
|
||||
self.db_remove('%d'%n)
|
||||
|
||||
@locked
|
||||
def clear(self):
|
||||
list.clear(self)
|
||||
if self.db:
|
||||
self.db.add_patch({'op': 'replace', 'path': key_path(self.path, None), 'value':[]})
|
||||
self.db_replace(None, [])
|
||||
|
||||
|
||||
|
||||
@@ -259,7 +282,8 @@ class JsonDB(Logger):
|
||||
# convert json to python objects
|
||||
data = self._convert_dict([], data)
|
||||
# convert dict to StoredDict
|
||||
self.data = StoredDict(data, self, [])
|
||||
self.data = StoredDict(data, self)
|
||||
self.data.set_parent('', None)
|
||||
# write file in case there was a db upgrade
|
||||
if self.storage and self.storage.file_exists():
|
||||
self.write_and_force_consolidation()
|
||||
@@ -333,6 +357,15 @@ class JsonDB(Logger):
|
||||
self.pending_changes.append(json.dumps(patch, cls=self.encoder))
|
||||
self.set_modified(True)
|
||||
|
||||
def add(self, path, key, value):
|
||||
self.add_patch({'op': 'add', 'path': key_path(path, key), 'value': value})
|
||||
|
||||
def replace(self, path, key, value):
|
||||
self.add_patch({'op': 'replace', 'path': key_path(path, key), 'value': value})
|
||||
|
||||
def remove(self, path, key):
|
||||
self.add_patch({'op': 'remove', 'path': key_path(path, key)})
|
||||
|
||||
@locked
|
||||
def get(self, key, default=None):
|
||||
v = self.data.get(key)
|
||||
|
||||
@@ -773,7 +773,7 @@ class Channel(AbstractChannel):
|
||||
Logger.__init__(self) # should be after short_channel_id is set
|
||||
self.lnworker = lnworker
|
||||
self.storage = state
|
||||
self.db_lock = self.storage.db.lock if self.storage.db else threading.RLock()
|
||||
self.db_lock = self.storage.lock
|
||||
self.config = {}
|
||||
self.config[LOCAL] = state["local_config"]
|
||||
self.config[REMOTE] = state["remote_config"]
|
||||
|
||||
@@ -1263,8 +1263,7 @@ class Peer(Logger, EventListener):
|
||||
"revocation_store": {},
|
||||
"channel_type": channel_type,
|
||||
}
|
||||
# set db to None, because we do not want to write updates until channel is saved
|
||||
return StoredDict(chan_dict, None, [])
|
||||
return StoredDict(chan_dict, self.lnworker.db)
|
||||
|
||||
@non_blocking_msg_handler
|
||||
async def on_open_channel(self, payload):
|
||||
|
||||
@@ -118,7 +118,7 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
|
||||
'revocation_store': {},
|
||||
'channel_type': channel_type,
|
||||
}
|
||||
return StoredDict(state, None, [])
|
||||
return StoredDict(state, None)
|
||||
|
||||
|
||||
def bip32(sequence):
|
||||
|
||||
@@ -14,8 +14,8 @@ class H(NamedTuple):
|
||||
|
||||
class TestHTLCManager(ElectrumTestCase):
|
||||
def test_adding_htlcs_race(self):
|
||||
A = HTLCManager(StoredDict({}, None, []))
|
||||
B = HTLCManager(StoredDict({}, None, []))
|
||||
A = HTLCManager(StoredDict({}, None))
|
||||
B = HTLCManager(StoredDict({}, None))
|
||||
A.channel_open_finished()
|
||||
B.channel_open_finished()
|
||||
ah0, bh0 = H('A', 0), H('B', 0)
|
||||
@@ -61,8 +61,8 @@ class TestHTLCManager(ElectrumTestCase):
|
||||
|
||||
def test_single_htlc_full_lifecycle(self):
|
||||
def htlc_lifecycle(htlc_success: bool):
|
||||
A = HTLCManager(StoredDict({}, None, []))
|
||||
B = HTLCManager(StoredDict({}, None, []))
|
||||
A = HTLCManager(StoredDict({}, None))
|
||||
B = HTLCManager(StoredDict({}, None))
|
||||
A.channel_open_finished()
|
||||
B.channel_open_finished()
|
||||
B.recv_htlc(A.send_htlc(H('A', 0)))
|
||||
@@ -134,8 +134,8 @@ class TestHTLCManager(ElectrumTestCase):
|
||||
|
||||
def test_remove_htlc_while_owing_commitment(self):
|
||||
def htlc_lifecycle(htlc_success: bool):
|
||||
A = HTLCManager(StoredDict({}, None, []))
|
||||
B = HTLCManager(StoredDict({}, None, []))
|
||||
A = HTLCManager(StoredDict({}, None))
|
||||
B = HTLCManager(StoredDict({}, None))
|
||||
A.channel_open_finished()
|
||||
B.channel_open_finished()
|
||||
ah0 = H('A', 0)
|
||||
@@ -171,8 +171,8 @@ class TestHTLCManager(ElectrumTestCase):
|
||||
htlc_lifecycle(htlc_success=False)
|
||||
|
||||
def test_adding_htlc_between_send_ctx_and_recv_rev(self):
|
||||
A = HTLCManager(StoredDict({}, None, []))
|
||||
B = HTLCManager(StoredDict({}, None, []))
|
||||
A = HTLCManager(StoredDict({}, None))
|
||||
B = HTLCManager(StoredDict({}, None))
|
||||
A.channel_open_finished()
|
||||
B.channel_open_finished()
|
||||
A.send_ctx()
|
||||
@@ -217,8 +217,8 @@ class TestHTLCManager(ElectrumTestCase):
|
||||
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
|
||||
|
||||
def test_unacked_local_updates(self):
|
||||
A = HTLCManager(StoredDict({}, None, []))
|
||||
B = HTLCManager(StoredDict({}, None, []))
|
||||
A = HTLCManager(StoredDict({}, None))
|
||||
B = HTLCManager(StoredDict({}, None))
|
||||
A.channel_open_finished()
|
||||
B.channel_open_finished()
|
||||
self.assertEqual({}, A.get_unacked_local_updates())
|
||||
|
||||
@@ -474,7 +474,7 @@ class TestLNUtil(ElectrumTestCase):
|
||||
]
|
||||
|
||||
for test in tests:
|
||||
receiver = RevocationStore(StoredDict({}, None, []))
|
||||
receiver = RevocationStore(StoredDict({}, None))
|
||||
for insert in test["inserts"]:
|
||||
secret = bytes.fromhex(insert["secret"])
|
||||
|
||||
@@ -497,7 +497,7 @@ class TestLNUtil(ElectrumTestCase):
|
||||
|
||||
def test_shachain_produce_consume(self):
|
||||
seed = bitcoin.sha256(b"shachaintest")
|
||||
consumer = RevocationStore(StoredDict({}, None, []))
|
||||
consumer = RevocationStore(StoredDict({}, None))
|
||||
for i in range(10000):
|
||||
secret = get_per_commitment_secret_from_seed(seed, RevocationStore.START_INDEX - i)
|
||||
try:
|
||||
@@ -507,7 +507,7 @@ class TestLNUtil(ElectrumTestCase):
|
||||
if i % 1000 == 0:
|
||||
c1 = consumer
|
||||
s1 = json.dumps(c1.storage, cls=MyEncoder)
|
||||
c2 = RevocationStore(StoredDict(json.loads(s1), None, []))
|
||||
c2 = RevocationStore(StoredDict(json.loads(s1), None))
|
||||
s2 = json.dumps(c2.storage, cls=MyEncoder)
|
||||
self.assertEqual(s1, s2)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user