StoredDict: use pointers instead of path
Instead of storing its own path, each StoredDict element stores its own key and a pointer to its parent. If a dict is removed from the db, its parent pointer is set to None. This makes self.path return None for all branches that have been pruned. This passes tests/tests_json_db.py and fixes issue #10000
This commit is contained in:
@@ -108,84 +108,111 @@ def key_path(path: Sequence[Union[str, int]], key: Optional[str]) -> str:
|
|||||||
items = [to_str(x) for x in path]
|
items = [to_str(x) for x in path]
|
||||||
if key is not None:
|
if key is not None:
|
||||||
items.append(to_str(key))
|
items.append(to_str(key))
|
||||||
return '/' + '/'.join(items)
|
return '/'.join(items)
|
||||||
|
|
||||||
|
class BaseStoredObject:
|
||||||
|
|
||||||
|
_db: 'JsonDB' = None
|
||||||
|
_key = None
|
||||||
|
_parent = None
|
||||||
|
_lock = None
|
||||||
|
|
||||||
|
def set_db(self, db):
|
||||||
|
self._db = db
|
||||||
|
self._lock = self._db.lock if self._db else threading.RLock()
|
||||||
|
|
||||||
|
def set_parent(self, key, parent):
|
||||||
|
self._key = key
|
||||||
|
self._parent = parent
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lock(self):
|
||||||
|
return self._lock
|
||||||
|
|
||||||
|
@property
|
||||||
|
def path(self) -> Sequence[str]:
|
||||||
|
# return None iff we are pruned from root
|
||||||
|
x = self
|
||||||
|
s = [x._key]
|
||||||
|
while x._parent is not None:
|
||||||
|
x = x._parent
|
||||||
|
s = [x._key] + s
|
||||||
|
if x._key != '':
|
||||||
|
return None
|
||||||
|
assert self._db is not None
|
||||||
|
return s
|
||||||
|
|
||||||
|
def db_add(self, key, value):
|
||||||
|
if self.path:
|
||||||
|
self._db.add(self.path, key, value)
|
||||||
|
|
||||||
|
def db_replace(self, key, value):
|
||||||
|
if self.path:
|
||||||
|
self._db.replace(self.path, key, value)
|
||||||
|
|
||||||
|
def db_remove(self, key):
|
||||||
|
if self.path:
|
||||||
|
self._db.remove(self.path, key)
|
||||||
|
|
||||||
|
|
||||||
class StoredObject:
|
class StoredObject(BaseStoredObject):
|
||||||
|
"""for attr.s objects """
|
||||||
db: 'JsonDB' = None
|
|
||||||
path = None
|
|
||||||
|
|
||||||
def __setattr__(self, key, value):
|
def __setattr__(self, key, value):
|
||||||
if self.db and key not in ['path', 'db'] and not key.startswith('_'):
|
if self.path and not key.startswith('_'):
|
||||||
if value != getattr(self, key):
|
if value != getattr(self, key):
|
||||||
self.db.add_patch({'op': 'replace', 'path': key_path(self.path, key), 'value': value})
|
self.db_replace(key, value)
|
||||||
object.__setattr__(self, key, value)
|
object.__setattr__(self, key, value)
|
||||||
|
|
||||||
def set_db(self, db, path):
|
|
||||||
self.db = db
|
|
||||||
self.path = path
|
|
||||||
|
|
||||||
def to_json(self):
|
def to_json(self):
|
||||||
d = dict(vars(self))
|
d = dict(vars(self))
|
||||||
d.pop('db', None)
|
|
||||||
d.pop('path', None)
|
|
||||||
# don't expose/store private stuff
|
# don't expose/store private stuff
|
||||||
d = {k: v for k, v in d.items()
|
d = {k: v for k, v in d.items()
|
||||||
if not k.startswith('_')}
|
if not k.startswith('_')}
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
_RaiseKeyError = object() # singleton for no-default behavior
|
_RaiseKeyError = object() # singleton for no-default behavior
|
||||||
|
|
||||||
class StoredDict(dict):
|
|
||||||
|
|
||||||
def __init__(self, data, db: 'JsonDB', path):
|
class StoredDict(dict, BaseStoredObject):
|
||||||
self.db = db
|
|
||||||
self.lock = self.db.lock if self.db else threading.RLock()
|
def __init__(self, data: dict, db: 'JsonDB'):
|
||||||
self.path = path
|
self.set_db(db)
|
||||||
# recursively convert dicts to StoredDict
|
# recursively convert dicts to StoredDict
|
||||||
for k, v in list(data.items()):
|
for k, v in list(data.items()):
|
||||||
self.__setitem__(k, v, patch=False)
|
self.__setitem__(k, v)
|
||||||
|
|
||||||
@locked
|
@locked
|
||||||
def __setitem__(self, key, v, patch=True):
|
def __setitem__(self, key, v):
|
||||||
is_new = key not in self
|
is_new = key not in self
|
||||||
# early return to prevent unnecessary disk writes
|
# early return to prevent unnecessary disk writes
|
||||||
if not is_new and patch:
|
if not is_new and self._db and json.dumps(v, cls=self._db.encoder) == json.dumps(self[key], cls=self._db.encoder):
|
||||||
if self.db and json.dumps(v, cls=self.db.encoder) == json.dumps(self[key], cls=self.db.encoder):
|
return
|
||||||
return
|
# convert dict to StoredDict.
|
||||||
# recursively set db and path
|
if type(v) == dict and (self._db is None or self._db._should_convert_to_stored_dict(key)):
|
||||||
if isinstance(v, StoredDict):
|
v = StoredDict(v, self._db)
|
||||||
#assert v.db is None
|
# convert list to StoredList
|
||||||
v.db = self.db
|
elif type(v) == list:
|
||||||
v.path = self.path + [key]
|
v = StoredList(v, self._db)
|
||||||
for k, vv in v.items():
|
|
||||||
v.__setitem__(k, vv, patch=False)
|
|
||||||
# recursively convert dict to StoredDict.
|
|
||||||
elif isinstance(v, dict):
|
|
||||||
if not self.db or self.db._should_convert_to_stored_dict(key):
|
|
||||||
v = StoredDict(v, self.db, self.path + [key])
|
|
||||||
# set parent of StoredObject
|
|
||||||
if isinstance(v, StoredObject):
|
|
||||||
v.set_db(self.db, self.path + [key])
|
|
||||||
# convert lists
|
|
||||||
if isinstance(v, list):
|
|
||||||
v = StoredList(v, self.db, self.path + [key])
|
|
||||||
# reject sets. they do not work well with jsonpatch
|
# reject sets. they do not work well with jsonpatch
|
||||||
if isinstance(v, set):
|
elif isinstance(v, set):
|
||||||
raise Exception(f"Do not store sets inside jsondb. path={self.path!r}")
|
raise Exception(f"Do not store sets inside jsondb. path={self.path!r}")
|
||||||
|
# set db for StoredObject, because it is not set in the constructor
|
||||||
|
if isinstance(v, StoredObject):
|
||||||
|
v.set_db(self._db)
|
||||||
|
# set parent
|
||||||
|
if isinstance(v, BaseStoredObject):
|
||||||
|
v.set_parent(key, self)
|
||||||
# set item
|
# set item
|
||||||
dict.__setitem__(self, key, v)
|
dict.__setitem__(self, key, v)
|
||||||
if self.db and patch:
|
self.db_add(key, v) if is_new else self.db_replace(key, v)
|
||||||
op = 'add' if is_new else 'replace'
|
|
||||||
self.db.add_patch({'op': op, 'path': key_path(self.path, key), 'value': v})
|
|
||||||
|
|
||||||
@locked
|
@locked
|
||||||
def __delitem__(self, key):
|
def __delitem__(self, key):
|
||||||
dict.__delitem__(self, key)
|
dict.__delitem__(self, key)
|
||||||
if self.db:
|
self.db_remove(key)
|
||||||
self.db.add_patch({'op': 'remove', 'path': key_path(self.path, key)})
|
|
||||||
|
|
||||||
@locked
|
@locked
|
||||||
def pop(self, key, v=_RaiseKeyError):
|
def pop(self, key, v=_RaiseKeyError):
|
||||||
@@ -195,8 +222,9 @@ class StoredDict(dict):
|
|||||||
else:
|
else:
|
||||||
return v
|
return v
|
||||||
r = dict.pop(self, key)
|
r = dict.pop(self, key)
|
||||||
if self.db:
|
self.db_remove(key)
|
||||||
self.db.add_patch({'op': 'remove', 'path': key_path(self.path, key)})
|
if isinstance(r, StoredDict):
|
||||||
|
r._parent = None
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def setdefault(self, key, default = None, /):
|
def setdefault(self, key, default = None, /):
|
||||||
@@ -205,33 +233,28 @@ class StoredDict(dict):
|
|||||||
return self[key]
|
return self[key]
|
||||||
|
|
||||||
|
|
||||||
class StoredList(list):
|
class StoredList(list, BaseStoredObject):
|
||||||
|
|
||||||
def __init__(self, data, db: 'JsonDB', path):
|
def __init__(self, data, db: 'JsonDB'):
|
||||||
list.__init__(self, data)
|
list.__init__(self, data)
|
||||||
self.db = db
|
self.set_db(db)
|
||||||
self.lock = self.db.lock if self.db else threading.RLock()
|
|
||||||
self.path = path
|
|
||||||
|
|
||||||
@locked
|
@locked
|
||||||
def append(self, item):
|
def append(self, item):
|
||||||
n = len(self)
|
n = len(self)
|
||||||
list.append(self, item)
|
list.append(self, item)
|
||||||
if self.db:
|
self.db_add('%d'%n, item)
|
||||||
self.db.add_patch({'op': 'add', 'path': key_path(self.path, '%d'%n), 'value':item})
|
|
||||||
|
|
||||||
@locked
|
@locked
|
||||||
def remove(self, item):
|
def remove(self, item):
|
||||||
n = self.index(item)
|
n = self.index(item)
|
||||||
list.remove(self, item)
|
list.remove(self, item)
|
||||||
if self.db:
|
self.db_remove('%d'%n)
|
||||||
self.db.add_patch({'op': 'remove', 'path': key_path(self.path, '%d'%n)})
|
|
||||||
|
|
||||||
@locked
|
@locked
|
||||||
def clear(self):
|
def clear(self):
|
||||||
list.clear(self)
|
list.clear(self)
|
||||||
if self.db:
|
self.db_replace(None, [])
|
||||||
self.db.add_patch({'op': 'replace', 'path': key_path(self.path, None), 'value':[]})
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -259,7 +282,8 @@ class JsonDB(Logger):
|
|||||||
# convert json to python objects
|
# convert json to python objects
|
||||||
data = self._convert_dict([], data)
|
data = self._convert_dict([], data)
|
||||||
# convert dict to StoredDict
|
# convert dict to StoredDict
|
||||||
self.data = StoredDict(data, self, [])
|
self.data = StoredDict(data, self)
|
||||||
|
self.data.set_parent('', None)
|
||||||
# write file in case there was a db upgrade
|
# write file in case there was a db upgrade
|
||||||
if self.storage and self.storage.file_exists():
|
if self.storage and self.storage.file_exists():
|
||||||
self.write_and_force_consolidation()
|
self.write_and_force_consolidation()
|
||||||
@@ -333,6 +357,15 @@ class JsonDB(Logger):
|
|||||||
self.pending_changes.append(json.dumps(patch, cls=self.encoder))
|
self.pending_changes.append(json.dumps(patch, cls=self.encoder))
|
||||||
self.set_modified(True)
|
self.set_modified(True)
|
||||||
|
|
||||||
|
def add(self, path, key, value):
|
||||||
|
self.add_patch({'op': 'add', 'path': key_path(path, key), 'value': value})
|
||||||
|
|
||||||
|
def replace(self, path, key, value):
|
||||||
|
self.add_patch({'op': 'replace', 'path': key_path(path, key), 'value': value})
|
||||||
|
|
||||||
|
def remove(self, path, key):
|
||||||
|
self.add_patch({'op': 'remove', 'path': key_path(path, key)})
|
||||||
|
|
||||||
@locked
|
@locked
|
||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
v = self.data.get(key)
|
v = self.data.get(key)
|
||||||
|
|||||||
@@ -773,7 +773,7 @@ class Channel(AbstractChannel):
|
|||||||
Logger.__init__(self) # should be after short_channel_id is set
|
Logger.__init__(self) # should be after short_channel_id is set
|
||||||
self.lnworker = lnworker
|
self.lnworker = lnworker
|
||||||
self.storage = state
|
self.storage = state
|
||||||
self.db_lock = self.storage.db.lock if self.storage.db else threading.RLock()
|
self.db_lock = self.storage.lock
|
||||||
self.config = {}
|
self.config = {}
|
||||||
self.config[LOCAL] = state["local_config"]
|
self.config[LOCAL] = state["local_config"]
|
||||||
self.config[REMOTE] = state["remote_config"]
|
self.config[REMOTE] = state["remote_config"]
|
||||||
|
|||||||
@@ -1263,8 +1263,7 @@ class Peer(Logger, EventListener):
|
|||||||
"revocation_store": {},
|
"revocation_store": {},
|
||||||
"channel_type": channel_type,
|
"channel_type": channel_type,
|
||||||
}
|
}
|
||||||
# set db to None, because we do not want to write updates until channel is saved
|
return StoredDict(chan_dict, self.lnworker.db)
|
||||||
return StoredDict(chan_dict, None, [])
|
|
||||||
|
|
||||||
@non_blocking_msg_handler
|
@non_blocking_msg_handler
|
||||||
async def on_open_channel(self, payload):
|
async def on_open_channel(self, payload):
|
||||||
|
|||||||
@@ -118,7 +118,7 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
|
|||||||
'revocation_store': {},
|
'revocation_store': {},
|
||||||
'channel_type': channel_type,
|
'channel_type': channel_type,
|
||||||
}
|
}
|
||||||
return StoredDict(state, None, [])
|
return StoredDict(state, None)
|
||||||
|
|
||||||
|
|
||||||
def bip32(sequence):
|
def bip32(sequence):
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ class H(NamedTuple):
|
|||||||
|
|
||||||
class TestHTLCManager(ElectrumTestCase):
|
class TestHTLCManager(ElectrumTestCase):
|
||||||
def test_adding_htlcs_race(self):
|
def test_adding_htlcs_race(self):
|
||||||
A = HTLCManager(StoredDict({}, None, []))
|
A = HTLCManager(StoredDict({}, None))
|
||||||
B = HTLCManager(StoredDict({}, None, []))
|
B = HTLCManager(StoredDict({}, None))
|
||||||
A.channel_open_finished()
|
A.channel_open_finished()
|
||||||
B.channel_open_finished()
|
B.channel_open_finished()
|
||||||
ah0, bh0 = H('A', 0), H('B', 0)
|
ah0, bh0 = H('A', 0), H('B', 0)
|
||||||
@@ -61,8 +61,8 @@ class TestHTLCManager(ElectrumTestCase):
|
|||||||
|
|
||||||
def test_single_htlc_full_lifecycle(self):
|
def test_single_htlc_full_lifecycle(self):
|
||||||
def htlc_lifecycle(htlc_success: bool):
|
def htlc_lifecycle(htlc_success: bool):
|
||||||
A = HTLCManager(StoredDict({}, None, []))
|
A = HTLCManager(StoredDict({}, None))
|
||||||
B = HTLCManager(StoredDict({}, None, []))
|
B = HTLCManager(StoredDict({}, None))
|
||||||
A.channel_open_finished()
|
A.channel_open_finished()
|
||||||
B.channel_open_finished()
|
B.channel_open_finished()
|
||||||
B.recv_htlc(A.send_htlc(H('A', 0)))
|
B.recv_htlc(A.send_htlc(H('A', 0)))
|
||||||
@@ -134,8 +134,8 @@ class TestHTLCManager(ElectrumTestCase):
|
|||||||
|
|
||||||
def test_remove_htlc_while_owing_commitment(self):
|
def test_remove_htlc_while_owing_commitment(self):
|
||||||
def htlc_lifecycle(htlc_success: bool):
|
def htlc_lifecycle(htlc_success: bool):
|
||||||
A = HTLCManager(StoredDict({}, None, []))
|
A = HTLCManager(StoredDict({}, None))
|
||||||
B = HTLCManager(StoredDict({}, None, []))
|
B = HTLCManager(StoredDict({}, None))
|
||||||
A.channel_open_finished()
|
A.channel_open_finished()
|
||||||
B.channel_open_finished()
|
B.channel_open_finished()
|
||||||
ah0 = H('A', 0)
|
ah0 = H('A', 0)
|
||||||
@@ -171,8 +171,8 @@ class TestHTLCManager(ElectrumTestCase):
|
|||||||
htlc_lifecycle(htlc_success=False)
|
htlc_lifecycle(htlc_success=False)
|
||||||
|
|
||||||
def test_adding_htlc_between_send_ctx_and_recv_rev(self):
|
def test_adding_htlc_between_send_ctx_and_recv_rev(self):
|
||||||
A = HTLCManager(StoredDict({}, None, []))
|
A = HTLCManager(StoredDict({}, None))
|
||||||
B = HTLCManager(StoredDict({}, None, []))
|
B = HTLCManager(StoredDict({}, None))
|
||||||
A.channel_open_finished()
|
A.channel_open_finished()
|
||||||
B.channel_open_finished()
|
B.channel_open_finished()
|
||||||
A.send_ctx()
|
A.send_ctx()
|
||||||
@@ -217,8 +217,8 @@ class TestHTLCManager(ElectrumTestCase):
|
|||||||
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
|
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
|
||||||
|
|
||||||
def test_unacked_local_updates(self):
|
def test_unacked_local_updates(self):
|
||||||
A = HTLCManager(StoredDict({}, None, []))
|
A = HTLCManager(StoredDict({}, None))
|
||||||
B = HTLCManager(StoredDict({}, None, []))
|
B = HTLCManager(StoredDict({}, None))
|
||||||
A.channel_open_finished()
|
A.channel_open_finished()
|
||||||
B.channel_open_finished()
|
B.channel_open_finished()
|
||||||
self.assertEqual({}, A.get_unacked_local_updates())
|
self.assertEqual({}, A.get_unacked_local_updates())
|
||||||
|
|||||||
@@ -474,7 +474,7 @@ class TestLNUtil(ElectrumTestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
for test in tests:
|
for test in tests:
|
||||||
receiver = RevocationStore(StoredDict({}, None, []))
|
receiver = RevocationStore(StoredDict({}, None))
|
||||||
for insert in test["inserts"]:
|
for insert in test["inserts"]:
|
||||||
secret = bytes.fromhex(insert["secret"])
|
secret = bytes.fromhex(insert["secret"])
|
||||||
|
|
||||||
@@ -497,7 +497,7 @@ class TestLNUtil(ElectrumTestCase):
|
|||||||
|
|
||||||
def test_shachain_produce_consume(self):
|
def test_shachain_produce_consume(self):
|
||||||
seed = bitcoin.sha256(b"shachaintest")
|
seed = bitcoin.sha256(b"shachaintest")
|
||||||
consumer = RevocationStore(StoredDict({}, None, []))
|
consumer = RevocationStore(StoredDict({}, None))
|
||||||
for i in range(10000):
|
for i in range(10000):
|
||||||
secret = get_per_commitment_secret_from_seed(seed, RevocationStore.START_INDEX - i)
|
secret = get_per_commitment_secret_from_seed(seed, RevocationStore.START_INDEX - i)
|
||||||
try:
|
try:
|
||||||
@@ -507,7 +507,7 @@ class TestLNUtil(ElectrumTestCase):
|
|||||||
if i % 1000 == 0:
|
if i % 1000 == 0:
|
||||||
c1 = consumer
|
c1 = consumer
|
||||||
s1 = json.dumps(c1.storage, cls=MyEncoder)
|
s1 = json.dumps(c1.storage, cls=MyEncoder)
|
||||||
c2 = RevocationStore(StoredDict(json.loads(s1), None, []))
|
c2 = RevocationStore(StoredDict(json.loads(s1), None))
|
||||||
s2 = json.dumps(c2.storage, cls=MyEncoder)
|
s2 = json.dumps(c2.storage, cls=MyEncoder)
|
||||||
self.assertEqual(s1, s2)
|
self.assertEqual(s1, s2)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user