Restructure wallet storage:
- Perform json deserializations in wallet_db - use StoredDict class that keeps tracks of its modifications
This commit is contained in:
@@ -1,14 +1,17 @@
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Sequence, Tuple, List, Dict
|
||||
from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING
|
||||
|
||||
from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction, FeeUpdate
|
||||
from .util import bh2u, bfh
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .json_db import StoredDict
|
||||
|
||||
class HTLCManager:
|
||||
|
||||
def __init__(self, *, log=None, initial_feerate=None):
|
||||
if log is None:
|
||||
def __init__(self, log:'StoredDict', *, initial_feerate=None):
|
||||
|
||||
if len(log) == 0:
|
||||
initial = {
|
||||
'adds': {},
|
||||
'locked_in': {},
|
||||
@@ -17,33 +20,18 @@ class HTLCManager:
|
||||
'fee_updates': {}, # "side who initiated fee update" -> action -> list of FeeUpdates
|
||||
'revack_pending': False,
|
||||
'next_htlc_id': 0,
|
||||
'ctn': -1, # oldest unrevoked ctx of sub
|
||||
'ctn': -1, # oldest unrevoked ctx of sub
|
||||
}
|
||||
log = {LOCAL: deepcopy(initial), REMOTE: deepcopy(initial)}
|
||||
else:
|
||||
assert type(log) is dict
|
||||
log = {(HTLCOwner(int(k)) if k in ("-1", "1") else k): v
|
||||
for k, v in deepcopy(log).items()}
|
||||
for sub in (LOCAL, REMOTE):
|
||||
log[sub]['adds'] = {int(htlc_id): UpdateAddHtlc(*htlc) for htlc_id, htlc in log[sub]['adds'].items()}
|
||||
coerceHtlcOwner2IntMap = lambda ctns: {HTLCOwner(int(owner)): ctn for owner, ctn in ctns.items()}
|
||||
# "side who offered htlc" -> action -> htlc_id -> whose ctx -> ctn
|
||||
log[sub]['locked_in'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['locked_in'].items()}
|
||||
log[sub]['settles'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['settles'].items()}
|
||||
log[sub]['fails'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['fails'].items()}
|
||||
# "side who initiated fee update" -> action -> list of FeeUpdates
|
||||
log[sub]['fee_updates'] = { int(x): FeeUpdate(**fee_upd) for x,fee_upd in log[sub]['fee_updates'].items() }
|
||||
|
||||
if 'unacked_local_updates2' not in log:
|
||||
log[LOCAL] = deepcopy(initial)
|
||||
log[REMOTE] = deepcopy(initial)
|
||||
log['unacked_local_updates2'] = {}
|
||||
log['unacked_local_updates2'] = {int(ctn): [bfh(msg) for msg in messages]
|
||||
for ctn, messages in log['unacked_local_updates2'].items()}
|
||||
|
||||
# maybe bootstrap fee_updates if initial_feerate was provided
|
||||
if initial_feerate is not None:
|
||||
assert type(initial_feerate) is int
|
||||
for sub in (LOCAL, REMOTE):
|
||||
if not log[sub]['fee_updates']:
|
||||
log[sub]['fee_updates'][0] = FeeUpdate(initial_feerate, ctn_local=0, ctn_remote=0)
|
||||
log[sub]['fee_updates'][0] = FeeUpdate(rate=initial_feerate, ctn_local=0, ctn_remote=0)
|
||||
self.log = log
|
||||
|
||||
def ctn_latest(self, sub: HTLCOwner) -> int:
|
||||
@@ -66,20 +54,6 @@ class HTLCManager:
|
||||
def get_next_htlc_id(self, sub: HTLCOwner) -> int:
|
||||
return self.log[sub]['next_htlc_id']
|
||||
|
||||
def to_save(self):
|
||||
log = deepcopy(self.log)
|
||||
for sub in (LOCAL, REMOTE):
|
||||
# adds
|
||||
d = {}
|
||||
for htlc_id, htlc in log[sub]['adds'].items():
|
||||
d[htlc_id] = (htlc[0], bh2u(htlc[1])) + htlc[2:]
|
||||
log[sub]['adds'] = d
|
||||
# fee_updates
|
||||
log[sub]['fee_updates'] = { x:fee_upd.to_json() for x, fee_upd in self.log[sub]['fee_updates'].items() }
|
||||
log['unacked_local_updates2'] = {ctn: [bh2u(msg) for msg in messages]
|
||||
for ctn, messages in log['unacked_local_updates2'].items()}
|
||||
return log
|
||||
|
||||
##### Actions on channel:
|
||||
|
||||
def channel_open_finished(self):
|
||||
@@ -132,7 +106,7 @@ class HTLCManager:
|
||||
def _new_feeupdate(self, fee_update: FeeUpdate, subject: HTLCOwner) -> None:
|
||||
# overwrite last fee update if not yet committed to by anyone; otherwise append
|
||||
d = self.log[subject]['fee_updates']
|
||||
assert type(d) is dict
|
||||
#assert type(d) is StoredDict
|
||||
n = len(d)
|
||||
last_fee_update = d[n-1]
|
||||
if (last_fee_update.ctn_local is None or last_fee_update.ctn_local > self.ctn_latest(LOCAL)) \
|
||||
@@ -194,7 +168,7 @@ class HTLCManager:
|
||||
del self.log[REMOTE]['locked_in'][htlc_id]
|
||||
del self.log[REMOTE]['adds'][htlc_id]
|
||||
if self.log[REMOTE]['locked_in']:
|
||||
self.log[REMOTE]['next_htlc_id'] = max(self.log[REMOTE]['locked_in']) + 1
|
||||
self.log[REMOTE]['next_htlc_id'] = max([int(x) for x in self.log[REMOTE]['locked_in'].keys()]) + 1
|
||||
else:
|
||||
self.log[REMOTE]['next_htlc_id'] = 0
|
||||
# htlcs removed
|
||||
@@ -217,12 +191,14 @@ class HTLCManager:
|
||||
ctn_idx = self.ctn_latest(REMOTE)
|
||||
else:
|
||||
ctn_idx = self.ctn_latest(REMOTE) + 1
|
||||
if ctn_idx not in self.log['unacked_local_updates2']:
|
||||
self.log['unacked_local_updates2'][ctn_idx] = []
|
||||
self.log['unacked_local_updates2'][ctn_idx].append(raw_update_msg)
|
||||
l = self.log['unacked_local_updates2'].get(ctn_idx, [])
|
||||
l.append(raw_update_msg.hex())
|
||||
self.log['unacked_local_updates2'][ctn_idx] = l
|
||||
|
||||
def get_unacked_local_updates(self) -> Dict[int, Sequence[bytes]]:
|
||||
return self.log['unacked_local_updates2']
|
||||
#return self.log['unacked_local_updates2']
|
||||
return {int(ctn): [bfh(msg) for msg in messages]
|
||||
for ctn, messages in self.log['unacked_local_updates2'].items()}
|
||||
|
||||
##### Queries re HTLCs:
|
||||
|
||||
|
||||
Reference in New Issue
Block a user