partial-writes using jsonpatch
- partial writes are append only. - StoredDict objects will append partial writes to the wallet file when items are added, replaced, removed. - Lists in the wallet file that have not been registered as StoredObject are converted to StoredList, which overloads append() and remove(). Those methods too will append partial writes to the wallet file. - Unlike the old jsonpatch branch, this branch does not support file encryption. Encrypted files are always fully rewritten, even if the pre-encryption change would otherwise be a partial write.
This commit is contained in:
@@ -7,6 +7,7 @@ aiohttp_socks>=0.3
|
|||||||
certifi
|
certifi
|
||||||
bitstring
|
bitstring
|
||||||
attrs>=20.1.0
|
attrs>=20.1.0
|
||||||
|
jsonpatch
|
||||||
|
|
||||||
# Note that we also need the dnspython[DNSSEC] extra which pulls in cryptography,
|
# Note that we also need the dnspython[DNSSEC] extra which pulls in cryptography,
|
||||||
# but as that is not pure-python it cannot be listed in this file!
|
# but as that is not pure-python it cannot be listed in this file!
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ import threading
|
|||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
import jsonpatch
|
||||||
|
|
||||||
from . import util
|
from . import util
|
||||||
from .util import WalletFileException, profiler
|
from .util import WalletFileException, profiler
|
||||||
@@ -80,22 +81,35 @@ def stored_in(name, _type=dict):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def key_path(path, key):
|
||||||
|
def to_str(x):
|
||||||
|
if isinstance(x, int):
|
||||||
|
return str(int(x))
|
||||||
|
else:
|
||||||
|
assert isinstance(x, str)
|
||||||
|
return x
|
||||||
|
return '/' + '/'.join([to_str(x) for x in path + [to_str(key)]])
|
||||||
|
|
||||||
|
|
||||||
class StoredObject:
|
class StoredObject:
|
||||||
|
|
||||||
db = None
|
db = None
|
||||||
|
path = None
|
||||||
|
|
||||||
def __setattr__(self, key, value):
|
def __setattr__(self, key, value):
|
||||||
if self.db:
|
if self.db and key not in ['path', 'db'] and not key.startswith('_'):
|
||||||
self.db.set_modified(True)
|
if value != getattr(self, key):
|
||||||
|
self.db.add_patch({'op': 'replace', 'path': key_path(self.path, key), 'value': value})
|
||||||
object.__setattr__(self, key, value)
|
object.__setattr__(self, key, value)
|
||||||
|
|
||||||
def set_db(self, db):
|
def set_db(self, db, path):
|
||||||
self.db = db
|
self.db = db
|
||||||
|
self.path = path
|
||||||
|
|
||||||
def to_json(self):
|
def to_json(self):
|
||||||
d = dict(vars(self))
|
d = dict(vars(self))
|
||||||
d.pop('db', None)
|
d.pop('db', None)
|
||||||
|
d.pop('path', None)
|
||||||
# don't expose/store private stuff
|
# don't expose/store private stuff
|
||||||
d = {k: v for k, v in d.items()
|
d = {k: v for k, v in d.items()
|
||||||
if not k.startswith('_')}
|
if not k.startswith('_')}
|
||||||
@@ -112,20 +126,22 @@ class StoredDict(dict):
|
|||||||
self.path = path
|
self.path = path
|
||||||
# recursively convert dicts to StoredDict
|
# recursively convert dicts to StoredDict
|
||||||
for k, v in list(data.items()):
|
for k, v in list(data.items()):
|
||||||
self.__setitem__(k, v)
|
self.__setitem__(k, v, patch=False)
|
||||||
|
|
||||||
@locked
|
@locked
|
||||||
def __setitem__(self, key, v):
|
def __setitem__(self, key, v, patch=True):
|
||||||
is_new = key not in self
|
is_new = key not in self
|
||||||
# early return to prevent unnecessary disk writes
|
# early return to prevent unnecessary disk writes
|
||||||
if not is_new and self[key] == v:
|
if not is_new and patch:
|
||||||
return
|
if self.db and json.dumps(v, cls=self.db.encoder) == json.dumps(self[key], cls=self.db.encoder):
|
||||||
|
return
|
||||||
# recursively set db and path
|
# recursively set db and path
|
||||||
if isinstance(v, StoredDict):
|
if isinstance(v, StoredDict):
|
||||||
|
#assert v.db is None
|
||||||
v.db = self.db
|
v.db = self.db
|
||||||
v.path = self.path + [key]
|
v.path = self.path + [key]
|
||||||
for k, vv in v.items():
|
for k, vv in v.items():
|
||||||
v[k] = vv
|
v.__setitem__(k, vv, patch=False)
|
||||||
# recursively convert dict to StoredDict.
|
# recursively convert dict to StoredDict.
|
||||||
# _convert_dict is called breadth-first
|
# _convert_dict is called breadth-first
|
||||||
elif isinstance(v, dict):
|
elif isinstance(v, dict):
|
||||||
@@ -139,29 +155,57 @@ class StoredDict(dict):
|
|||||||
v = self.db._convert_value(self.path, key, v)
|
v = self.db._convert_value(self.path, key, v)
|
||||||
# set parent of StoredObject
|
# set parent of StoredObject
|
||||||
if isinstance(v, StoredObject):
|
if isinstance(v, StoredObject):
|
||||||
v.set_db(self.db)
|
v.set_db(self.db, self.path + [key])
|
||||||
|
# convert lists
|
||||||
|
if isinstance(v, list):
|
||||||
|
v = StoredList(v, self.db, self.path + [key])
|
||||||
# set item
|
# set item
|
||||||
dict.__setitem__(self, key, v)
|
dict.__setitem__(self, key, v)
|
||||||
if self.db:
|
if self.db and patch:
|
||||||
self.db.set_modified(True)
|
op = 'add' if is_new else 'replace'
|
||||||
|
self.db.add_patch({'op': op, 'path': key_path(self.path, key), 'value': v})
|
||||||
|
|
||||||
@locked
|
@locked
|
||||||
def __delitem__(self, key):
|
def __delitem__(self, key):
|
||||||
dict.__delitem__(self, key)
|
dict.__delitem__(self, key)
|
||||||
if self.db:
|
if self.db:
|
||||||
self.db.set_modified(True)
|
self.db.add_patch({'op': 'remove', 'path': key_path(self.path, key)})
|
||||||
|
|
||||||
@locked
|
@locked
|
||||||
def pop(self, key, v=_RaiseKeyError):
|
def pop(self, key, v=_RaiseKeyError):
|
||||||
if v is _RaiseKeyError:
|
if key not in self:
|
||||||
r = dict.pop(self, key)
|
if v is _RaiseKeyError:
|
||||||
else:
|
raise KeyError(key)
|
||||||
r = dict.pop(self, key, v)
|
else:
|
||||||
|
return v
|
||||||
|
r = dict.pop(self, key)
|
||||||
if self.db:
|
if self.db:
|
||||||
self.db.set_modified(True)
|
self.db.add_patch({'op': 'remove', 'path': key_path(self.path, key)})
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
class StoredList(list):
|
||||||
|
|
||||||
|
def __init__(self, data, db, path):
|
||||||
|
list.__init__(self, data)
|
||||||
|
self.db = db
|
||||||
|
self.lock = self.db.lock if self.db else threading.RLock()
|
||||||
|
self.path = path
|
||||||
|
|
||||||
|
@locked
|
||||||
|
def append(self, item):
|
||||||
|
n = len(self)
|
||||||
|
list.append(self, item)
|
||||||
|
if self.db:
|
||||||
|
self.db.add_patch({'op': 'add', 'path': key_path(self.path, '%d'%n), 'value':item})
|
||||||
|
|
||||||
|
@locked
|
||||||
|
def remove(self, item):
|
||||||
|
n = self.index(item)
|
||||||
|
list.remove(self, item)
|
||||||
|
if self.db:
|
||||||
|
self.db.add_patch({'op': 'remove', 'path': key_path(self.path, '%d'%n)})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class JsonDB(Logger):
|
class JsonDB(Logger):
|
||||||
@@ -171,34 +215,39 @@ class JsonDB(Logger):
|
|||||||
self.lock = threading.RLock()
|
self.lock = threading.RLock()
|
||||||
self.storage = storage
|
self.storage = storage
|
||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
|
self.pending_changes = []
|
||||||
self._modified = False
|
self._modified = False
|
||||||
# load data
|
# load data
|
||||||
data = self.load_data(s)
|
data = self.load_data(s)
|
||||||
if upgrader:
|
if upgrader:
|
||||||
data, was_upgraded = upgrader(data)
|
data, was_upgraded = upgrader(data)
|
||||||
else:
|
self._modified |= was_upgraded
|
||||||
was_upgraded = False
|
|
||||||
# convert to StoredDict
|
# convert to StoredDict
|
||||||
self.data = StoredDict(data, self, [])
|
self.data = StoredDict(data, self, [])
|
||||||
# note: self._modified may have been affected by StoredDict
|
|
||||||
self._modified = was_upgraded
|
|
||||||
# write file in case there was a db upgrade
|
# write file in case there was a db upgrade
|
||||||
if self.storage and self.storage.file_exists():
|
if self.storage and self.storage.file_exists():
|
||||||
self.write()
|
self._write()
|
||||||
|
|
||||||
def load_data(self, s:str) -> dict:
|
def load_data(self, s:str) -> dict:
|
||||||
""" overloaded in wallet_db """
|
""" overloaded in wallet_db """
|
||||||
if s == '':
|
if s == '':
|
||||||
return {}
|
return {}
|
||||||
try:
|
try:
|
||||||
data = json.loads(s)
|
data = json.loads('[' + s + ']')
|
||||||
|
data, patches = data[0], data[1:]
|
||||||
except Exception:
|
except Exception:
|
||||||
if r := self.maybe_load_ast_data(s):
|
if r := self.maybe_load_ast_data(s):
|
||||||
data = r
|
data, patches = r, []
|
||||||
else:
|
else:
|
||||||
raise WalletFileException("Cannot read wallet file. (parsing failed)")
|
raise WalletFileException("Cannot read wallet file. (parsing failed)")
|
||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
raise WalletFileException("Malformed wallet file (not dict)")
|
raise WalletFileException("Malformed wallet file (not dict)")
|
||||||
|
if patches:
|
||||||
|
# apply patches
|
||||||
|
self.logger.info('found %d patches'%len(patches))
|
||||||
|
patch = jsonpatch.JsonPatch(patches)
|
||||||
|
data = patch.apply(data)
|
||||||
|
self.set_modified(True)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def maybe_load_ast_data(self, s):
|
def maybe_load_ast_data(self, s):
|
||||||
@@ -227,6 +276,11 @@ class JsonDB(Logger):
|
|||||||
def modified(self):
|
def modified(self):
|
||||||
return self._modified
|
return self._modified
|
||||||
|
|
||||||
|
@locked
|
||||||
|
def add_patch(self, patch):
|
||||||
|
self.pending_changes.append(json.dumps(patch, cls=self.encoder))
|
||||||
|
self.set_modified(True)
|
||||||
|
|
||||||
@locked
|
@locked
|
||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
v = self.data.get(key)
|
v = self.data.get(key)
|
||||||
@@ -259,6 +313,12 @@ class JsonDB(Logger):
|
|||||||
self.data[name] = {}
|
self.data[name] = {}
|
||||||
return self.data[name]
|
return self.data[name]
|
||||||
|
|
||||||
|
@locked
|
||||||
|
def get_stored_item(self, key, default) -> dict:
|
||||||
|
if key not in self.data:
|
||||||
|
self.data[key] = default
|
||||||
|
return self.data[key]
|
||||||
|
|
||||||
@locked
|
@locked
|
||||||
def dump(self, *, human_readable: bool = True) -> str:
|
def dump(self, *, human_readable: bool = True) -> str:
|
||||||
"""Serializes the DB as a string.
|
"""Serializes the DB as a string.
|
||||||
@@ -302,10 +362,27 @@ class JsonDB(Logger):
|
|||||||
v = constructor(v)
|
v = constructor(v)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@locked
|
||||||
def write(self):
|
def write(self):
|
||||||
with self.lock:
|
if self.storage.file_exists() and not self.storage.is_encrypted():
|
||||||
|
self._append_pending_changes()
|
||||||
|
else:
|
||||||
self._write()
|
self._write()
|
||||||
|
|
||||||
|
@locked
|
||||||
|
def _append_pending_changes(self):
|
||||||
|
if threading.current_thread().daemon:
|
||||||
|
self.logger.warning('daemon thread cannot write db')
|
||||||
|
return
|
||||||
|
if not self.pending_changes:
|
||||||
|
self.logger.info('no pending changes')
|
||||||
|
return
|
||||||
|
self.logger.info(f'appending {len(self.pending_changes)} pending changes')
|
||||||
|
s = ''.join([',\n' + x for x in self.pending_changes])
|
||||||
|
self.storage.append(s)
|
||||||
|
self.pending_changes = []
|
||||||
|
|
||||||
|
@locked
|
||||||
@profiler
|
@profiler
|
||||||
def _write(self):
|
def _write(self):
|
||||||
if threading.current_thread().daemon:
|
if threading.current_thread().daemon:
|
||||||
@@ -315,4 +392,5 @@ class JsonDB(Logger):
|
|||||||
return
|
return
|
||||||
json_str = self.dump(human_readable=not self.storage.is_encrypted())
|
json_str = self.dump(human_readable=not self.storage.is_encrypted())
|
||||||
self.storage.write(json_str)
|
self.storage.write(json_str)
|
||||||
|
self.pending_changes = []
|
||||||
self.set_modified(False)
|
self.set_modified(False)
|
||||||
|
|||||||
@@ -894,7 +894,8 @@ class Peer(Logger):
|
|||||||
"revocation_store": {},
|
"revocation_store": {},
|
||||||
"channel_type": channel_type,
|
"channel_type": channel_type,
|
||||||
}
|
}
|
||||||
return StoredDict(chan_dict, self.lnworker.db if self.lnworker else None, [])
|
# set db to None, because we do not want to write updates until channel is saved
|
||||||
|
return StoredDict(chan_dict, None, [])
|
||||||
|
|
||||||
async def on_open_channel(self, payload):
|
async def on_open_channel(self, payload):
|
||||||
"""Implements the channel acceptance flow.
|
"""Implements the channel acceptance flow.
|
||||||
|
|||||||
@@ -86,12 +86,10 @@ class WalletStorage(Logger):
|
|||||||
f.write(s)
|
f.write(s)
|
||||||
f.flush()
|
f.flush()
|
||||||
os.fsync(f.fileno())
|
os.fsync(f.fileno())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
mode = os.stat(self.path).st_mode
|
mode = os.stat(self.path).st_mode
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
mode = stat.S_IREAD | stat.S_IWRITE
|
mode = stat.S_IREAD | stat.S_IWRITE
|
||||||
|
|
||||||
# assert that wallet file does not exist, to prevent wallet corruption (see issue #5082)
|
# assert that wallet file does not exist, to prevent wallet corruption (see issue #5082)
|
||||||
if not self.file_exists():
|
if not self.file_exists():
|
||||||
assert not os.path.exists(self.path)
|
assert not os.path.exists(self.path)
|
||||||
@@ -100,6 +98,15 @@ class WalletStorage(Logger):
|
|||||||
self._file_exists = True
|
self._file_exists = True
|
||||||
self.logger.info(f"saved {self.path}")
|
self.logger.info(f"saved {self.path}")
|
||||||
|
|
||||||
|
def append(self, data: str) -> None:
|
||||||
|
""" append data to file. for the moment, only non-encrypted file"""
|
||||||
|
assert not self.is_encrypted()
|
||||||
|
with open(self.path, "r+") as f:
|
||||||
|
f.seek(0, os.SEEK_END)
|
||||||
|
f.write(data)
|
||||||
|
f.flush()
|
||||||
|
os.fsync(f.fileno())
|
||||||
|
|
||||||
def file_exists(self) -> bool:
|
def file_exists(self) -> bool:
|
||||||
return self._file_exists
|
return self._file_exists
|
||||||
|
|
||||||
@@ -179,6 +186,7 @@ class WalletStorage(Logger):
|
|||||||
def encrypt_before_writing(self, plaintext: str) -> str:
|
def encrypt_before_writing(self, plaintext: str) -> str:
|
||||||
s = plaintext
|
s = plaintext
|
||||||
if self.pubkey:
|
if self.pubkey:
|
||||||
|
self.decrypted = plaintext
|
||||||
s = bytes(s, 'utf8')
|
s = bytes(s, 'utf8')
|
||||||
c = zlib.compress(s, level=zlib.Z_BEST_SPEED)
|
c = zlib.compress(s, level=zlib.Z_BEST_SPEED)
|
||||||
enc_magic = self._get_encryption_magic()
|
enc_magic = self._get_encryption_magic()
|
||||||
|
|||||||
@@ -2828,7 +2828,9 @@ class Abstract_Wallet(ABC, Logger, EventListener):
|
|||||||
self._update_password_for_keystore(old_pw, new_pw)
|
self._update_password_for_keystore(old_pw, new_pw)
|
||||||
encrypt_keystore = self.can_have_keystore_encryption()
|
encrypt_keystore = self.can_have_keystore_encryption()
|
||||||
self.db.set_keystore_encryption(bool(new_pw) and encrypt_keystore)
|
self.db.set_keystore_encryption(bool(new_pw) and encrypt_keystore)
|
||||||
self.save_db()
|
## save changes
|
||||||
|
if self.storage and self.storage.file_exists():
|
||||||
|
self.db._write()
|
||||||
self.unlock(None)
|
self.unlock(None)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user