lnchan: use NamedTuple for logs instead of dict with static keys (adds, locked_in, settles, fails)
This commit is contained in:
@@ -26,7 +26,7 @@ from collections import namedtuple, defaultdict
|
|||||||
import binascii
|
import binascii
|
||||||
import json
|
import json
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import Optional, Dict, List, Tuple
|
from typing import Optional, Dict, List, Tuple, NamedTuple, Set
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
from .util import bfh, PrintError, bh2u
|
from .util import bfh, PrintError, bh2u
|
||||||
@@ -121,6 +121,20 @@ def str_bytes_dict_from_save(x):
|
|||||||
def str_bytes_dict_to_save(x):
|
def str_bytes_dict_to_save(x):
|
||||||
return {str(k): bh2u(v) for k, v in x.items()}
|
return {str(k): bh2u(v) for k, v in x.items()}
|
||||||
|
|
||||||
|
class HtlcChanges(NamedTuple):
|
||||||
|
# ints are htlc ids
|
||||||
|
adds: Dict[int, UpdateAddHtlc]
|
||||||
|
settles: Set[int]
|
||||||
|
fails: Set[int]
|
||||||
|
locked_in: Set[int]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def new():
|
||||||
|
"""
|
||||||
|
Since we can't use default arguments for these types (they would be shared among instances)
|
||||||
|
"""
|
||||||
|
return HtlcChanges({}, set(), set(), set())
|
||||||
|
|
||||||
class Channel(PrintError):
|
class Channel(PrintError):
|
||||||
def diagnostic_name(self):
|
def diagnostic_name(self):
|
||||||
if self.name:
|
if self.name:
|
||||||
@@ -158,18 +172,12 @@ class Channel(PrintError):
|
|||||||
# any past commitment transaction and use that instead; until then...
|
# any past commitment transaction and use that instead; until then...
|
||||||
self.remote_commitment_to_be_revoked = Transaction(state["remote_commitment_to_be_revoked"])
|
self.remote_commitment_to_be_revoked = Transaction(state["remote_commitment_to_be_revoked"])
|
||||||
|
|
||||||
template = lambda: {
|
self.log = {LOCAL: HtlcChanges.new(), REMOTE: HtlcChanges.new()}
|
||||||
'adds': {}, # Dict[HTLC_ID, UpdateAddHtlc]
|
|
||||||
'settles': [], # List[HTLC_ID]
|
|
||||||
'fails': [], # List[HTLC_ID]
|
|
||||||
'locked_in': [], # List[HTLC_ID]
|
|
||||||
}
|
|
||||||
self.log = {LOCAL: template(), REMOTE: template()}
|
|
||||||
for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]:
|
for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]:
|
||||||
if strname not in state: continue
|
if strname not in state: continue
|
||||||
for y in state[strname]:
|
for y in state[strname]:
|
||||||
htlc = UpdateAddHtlc(**y)
|
htlc = UpdateAddHtlc(**y)
|
||||||
self.log[subject]['adds'][htlc.htlc_id] = htlc
|
self.log[subject].adds[htlc.htlc_id] = htlc
|
||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
@@ -185,6 +193,9 @@ class Channel(PrintError):
|
|||||||
|
|
||||||
self.settled = {LOCAL: state.get('settled_local', []), REMOTE: state.get('settled_remote', [])}
|
self.settled = {LOCAL: state.get('settled_local', []), REMOTE: state.get('settled_remote', [])}
|
||||||
|
|
||||||
|
for sub in (LOCAL, REMOTE):
|
||||||
|
self.log[sub].locked_in.update(self.log[sub].adds.keys())
|
||||||
|
|
||||||
def set_state(self, state: str):
|
def set_state(self, state: str):
|
||||||
self._state = state
|
self._state = state
|
||||||
|
|
||||||
@@ -232,7 +243,7 @@ class Channel(PrintError):
|
|||||||
assert type(htlc) is dict
|
assert type(htlc) is dict
|
||||||
self._check_can_pay(htlc['amount_msat'])
|
self._check_can_pay(htlc['amount_msat'])
|
||||||
htlc = UpdateAddHtlc(**htlc, htlc_id=self.config[LOCAL].next_htlc_id)
|
htlc = UpdateAddHtlc(**htlc, htlc_id=self.config[LOCAL].next_htlc_id)
|
||||||
self.log[LOCAL]['adds'][htlc.htlc_id] = htlc
|
self.log[LOCAL].adds[htlc.htlc_id] = htlc
|
||||||
self.print_error("add_htlc")
|
self.print_error("add_htlc")
|
||||||
self.config[LOCAL]=self.config[LOCAL]._replace(next_htlc_id=htlc.htlc_id + 1)
|
self.config[LOCAL]=self.config[LOCAL]._replace(next_htlc_id=htlc.htlc_id + 1)
|
||||||
return htlc.htlc_id
|
return htlc.htlc_id
|
||||||
@@ -251,7 +262,7 @@ class Channel(PrintError):
|
|||||||
raise RemoteMisbehaving('Remote dipped below channel reserve.' +\
|
raise RemoteMisbehaving('Remote dipped below channel reserve.' +\
|
||||||
f' Available at remote: {self.available_to_spend(REMOTE)},' +\
|
f' Available at remote: {self.available_to_spend(REMOTE)},' +\
|
||||||
f' HTLC amount: {htlc.amount_msat}')
|
f' HTLC amount: {htlc.amount_msat}')
|
||||||
adds = self.log[REMOTE]['adds']
|
adds = self.log[REMOTE].adds
|
||||||
adds[htlc.htlc_id] = htlc
|
adds[htlc.htlc_id] = htlc
|
||||||
self.print_error("receive_htlc")
|
self.print_error("receive_htlc")
|
||||||
self.config[REMOTE]=self.config[REMOTE]._replace(next_htlc_id=htlc.htlc_id + 1)
|
self.config[REMOTE]=self.config[REMOTE]._replace(next_htlc_id=htlc.htlc_id + 1)
|
||||||
@@ -309,11 +320,11 @@ class Channel(PrintError):
|
|||||||
for sub in (LOCAL, REMOTE):
|
for sub in (LOCAL, REMOTE):
|
||||||
log = self.log[sub]
|
log = self.log[sub]
|
||||||
yield (sub, deepcopy(log))
|
yield (sub, deepcopy(log))
|
||||||
for htlc_id in log['fails']:
|
for htlc_id in log.fails:
|
||||||
log['adds'].pop(htlc_id)
|
log.adds.pop(htlc_id)
|
||||||
log['fails'].clear()
|
log.fails.clear()
|
||||||
|
|
||||||
self.log[subject]['locked_in'] |= self.log[subject]['adds'].keys()
|
self.log[subject].locked_in.update(self.log[subject].adds.keys())
|
||||||
|
|
||||||
def receive_new_commitment(self, sig, htlc_sigs):
|
def receive_new_commitment(self, sig, htlc_sigs):
|
||||||
"""
|
"""
|
||||||
@@ -474,11 +485,11 @@ class Channel(PrintError):
|
|||||||
"""
|
"""
|
||||||
old_amount = htlcsum(self.htlcs(subject, False))
|
old_amount = htlcsum(self.htlcs(subject, False))
|
||||||
|
|
||||||
for htlc_id in self.log[subject]['settles']:
|
for htlc_id in self.log[subject].settles:
|
||||||
adds = self.log[subject]['adds']
|
adds = self.log[subject].adds
|
||||||
htlc = adds.pop(htlc_id)
|
htlc = adds.pop(htlc_id)
|
||||||
self.settled[subject].append(htlc.amount_msat)
|
self.settled[subject].append(htlc.amount_msat)
|
||||||
self.log[subject]['settles'].clear()
|
self.log[subject].settles.clear()
|
||||||
|
|
||||||
return old_amount - htlcsum(self.htlcs(subject, False))
|
return old_amount - htlcsum(self.htlcs(subject, False))
|
||||||
|
|
||||||
@@ -533,7 +544,7 @@ class Channel(PrintError):
|
|||||||
pending outgoing HTLCs, is used in the UI.
|
pending outgoing HTLCs, is used in the UI.
|
||||||
"""
|
"""
|
||||||
return self.balance(subject)\
|
return self.balance(subject)\
|
||||||
- htlcsum(self.log[subject]['adds'].values())
|
- htlcsum(self.log[subject].adds.values())
|
||||||
|
|
||||||
def available_to_spend(self, subject):
|
def available_to_spend(self, subject):
|
||||||
"""
|
"""
|
||||||
@@ -541,7 +552,7 @@ class Channel(PrintError):
|
|||||||
not be used in the UI cause it fluctuates (commit fee)
|
not be used in the UI cause it fluctuates (commit fee)
|
||||||
"""
|
"""
|
||||||
return self.balance_minus_outgoing_htlcs(subject)\
|
return self.balance_minus_outgoing_htlcs(subject)\
|
||||||
- htlcsum(self.log[subject]['adds'].values())\
|
- htlcsum(self.log[subject].adds.values())\
|
||||||
- self.config[-subject].reserve_sat * 1000\
|
- self.config[-subject].reserve_sat * 1000\
|
||||||
- calc_onchain_fees(
|
- calc_onchain_fees(
|
||||||
# TODO should we include a potential new htlc, when we are called from receive_htlc?
|
# TODO should we include a potential new htlc, when we are called from receive_htlc?
|
||||||
@@ -601,10 +612,10 @@ class Channel(PrintError):
|
|||||||
"""
|
"""
|
||||||
update_log = self.log[subject]
|
update_log = self.log[subject]
|
||||||
res = []
|
res = []
|
||||||
for htlc in update_log['adds'].values():
|
for htlc in update_log.adds.values():
|
||||||
locked_in = htlc.htlc_id in update_log['locked_in']
|
locked_in = htlc.htlc_id in update_log.locked_in
|
||||||
settled = htlc.htlc_id in update_log['settles']
|
settled = htlc.htlc_id in update_log.settles
|
||||||
failed = htlc.htlc_id in update_log['fails']
|
failed = htlc.htlc_id in update_log.fails
|
||||||
if not locked_in:
|
if not locked_in:
|
||||||
continue
|
continue
|
||||||
if only_pending == (settled or failed):
|
if only_pending == (settled or failed):
|
||||||
@@ -617,25 +628,33 @@ class Channel(PrintError):
|
|||||||
SettleHTLC attempts to settle an existing outstanding received HTLC.
|
SettleHTLC attempts to settle an existing outstanding received HTLC.
|
||||||
"""
|
"""
|
||||||
self.print_error("settle_htlc")
|
self.print_error("settle_htlc")
|
||||||
htlc = self.log[REMOTE]['adds'][htlc_id]
|
log = self.log[REMOTE]
|
||||||
|
htlc = log.adds[htlc_id]
|
||||||
assert htlc.payment_hash == sha256(preimage)
|
assert htlc.payment_hash == sha256(preimage)
|
||||||
self.log[REMOTE]['settles'].append(htlc_id)
|
assert htlc_id not in log.settles
|
||||||
|
log.settles.add(htlc_id)
|
||||||
# not saving preimage because it's already saved in LNWorker.invoices
|
# not saving preimage because it's already saved in LNWorker.invoices
|
||||||
|
|
||||||
def receive_htlc_settle(self, preimage, htlc_id):
|
def receive_htlc_settle(self, preimage, htlc_id):
|
||||||
self.print_error("receive_htlc_settle")
|
self.print_error("receive_htlc_settle")
|
||||||
htlc = self.log[LOCAL]['adds'][htlc_id]
|
log = self.log[LOCAL]
|
||||||
|
htlc = log.adds[htlc_id]
|
||||||
assert htlc.payment_hash == sha256(preimage)
|
assert htlc.payment_hash == sha256(preimage)
|
||||||
self.log[LOCAL]['settles'].append(htlc_id)
|
assert htlc_id not in log.settles
|
||||||
|
log.settles.add(htlc_id)
|
||||||
# we don't save the preimage because we don't need to forward it anyway
|
# we don't save the preimage because we don't need to forward it anyway
|
||||||
|
|
||||||
def fail_htlc(self, htlc_id):
|
def fail_htlc(self, htlc_id):
|
||||||
self.print_error("fail_htlc")
|
self.print_error("fail_htlc")
|
||||||
self.log[REMOTE]['fails'].append(htlc_id)
|
log = self.log[REMOTE]
|
||||||
|
assert htlc_id not in log.fails
|
||||||
|
log.fails.add(htlc_id)
|
||||||
|
|
||||||
def receive_fail_htlc(self, htlc_id):
|
def receive_fail_htlc(self, htlc_id):
|
||||||
self.print_error("receive_fail_htlc")
|
self.print_error("receive_fail_htlc")
|
||||||
self.log[LOCAL]['fails'].append(htlc_id)
|
log = self.log[LOCAL]
|
||||||
|
assert htlc_id not in log.fails
|
||||||
|
log.fails.add(htlc_id)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_height(self):
|
def current_height(self):
|
||||||
@@ -666,8 +685,8 @@ class Channel(PrintError):
|
|||||||
removed = []
|
removed = []
|
||||||
htlcs = []
|
htlcs = []
|
||||||
log = self.log[subject]
|
log = self.log[subject]
|
||||||
for htlc_id, i in log['adds'].items():
|
for i in log.adds.values():
|
||||||
locked_in = htlc_id in log['locked_in']
|
locked_in = i.htlc_id in log.locked_in
|
||||||
if locked_in:
|
if locked_in:
|
||||||
htlcs.append(i._asdict())
|
htlcs.append(i._asdict())
|
||||||
else:
|
else:
|
||||||
@@ -710,18 +729,26 @@ class Channel(PrintError):
|
|||||||
|
|
||||||
def serialize(self):
|
def serialize(self):
|
||||||
namedtuples_to_dict = lambda v: {i: j._asdict() if isinstance(j, tuple) else j for i, j in v._asdict().items()}
|
namedtuples_to_dict = lambda v: {i: j._asdict() if isinstance(j, tuple) else j for i, j in v._asdict().items()}
|
||||||
serialized_channel = {k: namedtuples_to_dict(v) if isinstance(v, tuple) else v for k, v in self.to_save().items()}
|
serialized_channel = {}
|
||||||
|
to_save_ref = self.to_save()
|
||||||
|
for k, v in to_save_ref.items():
|
||||||
|
if isinstance(v, tuple):
|
||||||
|
serialized_channel[k] = namedtuples_to_dict(v)
|
||||||
|
else:
|
||||||
|
serialized_channel[k] = v
|
||||||
dumped = ChannelJsonEncoder().encode(serialized_channel)
|
dumped = ChannelJsonEncoder().encode(serialized_channel)
|
||||||
roundtripped = json.loads(dumped)
|
roundtripped = json.loads(dumped)
|
||||||
reconstructed = Channel(roundtripped)
|
reconstructed = Channel(roundtripped)
|
||||||
if reconstructed.to_save() != self.to_save():
|
to_save_new = reconstructed.to_save()
|
||||||
from pprint import pformat
|
if to_save_new != to_save_ref:
|
||||||
|
from pprint import PrettyPrinter
|
||||||
|
pp = PrettyPrinter(indent=168)
|
||||||
try:
|
try:
|
||||||
from deepdiff import DeepDiff
|
from deepdiff import DeepDiff
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise Exception("Channels did not roundtrip serialization without changes:\n" + pformat(reconstructed.to_save()) + "\n" + pformat(self.to_save()))
|
raise Exception("Channels did not roundtrip serialization without changes:\n" + pp.pformat(to_save_ref) + "\n" + pp.pformat(to_save_new))
|
||||||
else:
|
else:
|
||||||
raise Exception("Channels did not roundtrip serialization without changes:\n" + pformat(DeepDiff(reconstructed.to_save(), self.to_save())))
|
raise Exception("Channels did not roundtrip serialization without changes:\n" + pp.pformat(DeepDiff(to_save_ref, to_save_new)))
|
||||||
return roundtripped
|
return roundtripped
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
|||||||
@@ -183,7 +183,7 @@ class TestChannel(unittest.TestCase):
|
|||||||
|
|
||||||
self.bob_pending_remote_balance = after
|
self.bob_pending_remote_balance = after
|
||||||
|
|
||||||
self.htlc = self.bob_channel.log[lnutil.REMOTE]['adds'][0]
|
self.htlc = self.bob_channel.log[lnutil.REMOTE].adds[0]
|
||||||
|
|
||||||
def test_SimpleAddSettleWorkflow(self):
|
def test_SimpleAddSettleWorkflow(self):
|
||||||
alice_channel, bob_channel = self.alice_channel, self.bob_channel
|
alice_channel, bob_channel = self.alice_channel, self.bob_channel
|
||||||
@@ -217,6 +217,10 @@ class TestChannel(unittest.TestCase):
|
|||||||
# forward since she's sending an outgoing HTLC.
|
# forward since she's sending an outgoing HTLC.
|
||||||
alice_channel.receive_revocation(bobRevocation)
|
alice_channel.receive_revocation(bobRevocation)
|
||||||
|
|
||||||
|
# test serializing with locked_in htlc
|
||||||
|
self.assertEqual(len(alice_channel.to_save()['local_log']), 1)
|
||||||
|
alice_channel.serialize()
|
||||||
|
|
||||||
# Alice then processes bob's signature, and since she just received
|
# Alice then processes bob's signature, and since she just received
|
||||||
# the revocation, she expect this signature to cover everything up to
|
# the revocation, she expect this signature to cover everything up to
|
||||||
# the point where she sent her signature, including the HTLC.
|
# the point where she sent her signature, including the HTLC.
|
||||||
|
|||||||
Reference in New Issue
Block a user