1
0

Restructure wallet storage:

- Perform json deserializations in wallet_db
 - use StoredDict class that keeps tracks of its modifications
This commit is contained in:
ThomasV
2020-02-04 13:35:58 +01:00
parent 0a9e7cb04e
commit dbceed2647
14 changed files with 303 additions and 291 deletions

View File

@@ -1,14 +1,17 @@
from copy import deepcopy
from typing import Optional, Sequence, Tuple, List, Dict
from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING
from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction, FeeUpdate
from .util import bh2u, bfh
if TYPE_CHECKING:
from .json_db import StoredDict
class HTLCManager:
def __init__(self, *, log=None, initial_feerate=None):
if log is None:
def __init__(self, log:'StoredDict', *, initial_feerate=None):
if len(log) == 0:
initial = {
'adds': {},
'locked_in': {},
@@ -17,33 +20,18 @@ class HTLCManager:
'fee_updates': {}, # "side who initiated fee update" -> action -> list of FeeUpdates
'revack_pending': False,
'next_htlc_id': 0,
'ctn': -1, # oldest unrevoked ctx of sub
'ctn': -1, # oldest unrevoked ctx of sub
}
log = {LOCAL: deepcopy(initial), REMOTE: deepcopy(initial)}
else:
assert type(log) is dict
log = {(HTLCOwner(int(k)) if k in ("-1", "1") else k): v
for k, v in deepcopy(log).items()}
for sub in (LOCAL, REMOTE):
log[sub]['adds'] = {int(htlc_id): UpdateAddHtlc(*htlc) for htlc_id, htlc in log[sub]['adds'].items()}
coerceHtlcOwner2IntMap = lambda ctns: {HTLCOwner(int(owner)): ctn for owner, ctn in ctns.items()}
# "side who offered htlc" -> action -> htlc_id -> whose ctx -> ctn
log[sub]['locked_in'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['locked_in'].items()}
log[sub]['settles'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['settles'].items()}
log[sub]['fails'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['fails'].items()}
# "side who initiated fee update" -> action -> list of FeeUpdates
log[sub]['fee_updates'] = { int(x): FeeUpdate(**fee_upd) for x,fee_upd in log[sub]['fee_updates'].items() }
if 'unacked_local_updates2' not in log:
log[LOCAL] = deepcopy(initial)
log[REMOTE] = deepcopy(initial)
log['unacked_local_updates2'] = {}
log['unacked_local_updates2'] = {int(ctn): [bfh(msg) for msg in messages]
for ctn, messages in log['unacked_local_updates2'].items()}
# maybe bootstrap fee_updates if initial_feerate was provided
if initial_feerate is not None:
assert type(initial_feerate) is int
for sub in (LOCAL, REMOTE):
if not log[sub]['fee_updates']:
log[sub]['fee_updates'][0] = FeeUpdate(initial_feerate, ctn_local=0, ctn_remote=0)
log[sub]['fee_updates'][0] = FeeUpdate(rate=initial_feerate, ctn_local=0, ctn_remote=0)
self.log = log
def ctn_latest(self, sub: HTLCOwner) -> int:
@@ -66,20 +54,6 @@ class HTLCManager:
def get_next_htlc_id(self, sub: HTLCOwner) -> int:
return self.log[sub]['next_htlc_id']
def to_save(self):
log = deepcopy(self.log)
for sub in (LOCAL, REMOTE):
# adds
d = {}
for htlc_id, htlc in log[sub]['adds'].items():
d[htlc_id] = (htlc[0], bh2u(htlc[1])) + htlc[2:]
log[sub]['adds'] = d
# fee_updates
log[sub]['fee_updates'] = { x:fee_upd.to_json() for x, fee_upd in self.log[sub]['fee_updates'].items() }
log['unacked_local_updates2'] = {ctn: [bh2u(msg) for msg in messages]
for ctn, messages in log['unacked_local_updates2'].items()}
return log
##### Actions on channel:
def channel_open_finished(self):
@@ -132,7 +106,7 @@ class HTLCManager:
def _new_feeupdate(self, fee_update: FeeUpdate, subject: HTLCOwner) -> None:
# overwrite last fee update if not yet committed to by anyone; otherwise append
d = self.log[subject]['fee_updates']
assert type(d) is dict
#assert type(d) is StoredDict
n = len(d)
last_fee_update = d[n-1]
if (last_fee_update.ctn_local is None or last_fee_update.ctn_local > self.ctn_latest(LOCAL)) \
@@ -194,7 +168,7 @@ class HTLCManager:
del self.log[REMOTE]['locked_in'][htlc_id]
del self.log[REMOTE]['adds'][htlc_id]
if self.log[REMOTE]['locked_in']:
self.log[REMOTE]['next_htlc_id'] = max(self.log[REMOTE]['locked_in']) + 1
self.log[REMOTE]['next_htlc_id'] = max([int(x) for x in self.log[REMOTE]['locked_in'].keys()]) + 1
else:
self.log[REMOTE]['next_htlc_id'] = 0
# htlcs removed
@@ -217,12 +191,14 @@ class HTLCManager:
ctn_idx = self.ctn_latest(REMOTE)
else:
ctn_idx = self.ctn_latest(REMOTE) + 1
if ctn_idx not in self.log['unacked_local_updates2']:
self.log['unacked_local_updates2'][ctn_idx] = []
self.log['unacked_local_updates2'][ctn_idx].append(raw_update_msg)
l = self.log['unacked_local_updates2'].get(ctn_idx, [])
l.append(raw_update_msg.hex())
self.log['unacked_local_updates2'][ctn_idx] = l
def get_unacked_local_updates(self) -> Dict[int, Sequence[bytes]]:
return self.log['unacked_local_updates2']
#return self.log['unacked_local_updates2']
return {int(ctn): [bfh(msg) for msg in messages]
for ctn, messages in self.log['unacked_local_updates2'].items()}
##### Queries re HTLCs: