1
0

lnhtlc: save logs and feeupdates

This commit is contained in:
Janus
2018-09-18 18:38:57 +02:00
committed by ThomasV
parent eca5545004
commit d5d9270d0c
5 changed files with 130 additions and 69 deletions

View File

@@ -16,7 +16,7 @@ from .lnutil import sign_and_get_sig_string
from .lnutil import make_htlc_tx_with_open_channel, make_commitment, make_received_htlc, make_offered_htlc
from .lnutil import HTLC_TIMEOUT_WEIGHT, HTLC_SUCCESS_WEIGHT
from .lnutil import funding_output_script, extract_ctn_from_tx_and_chan
from .lnutil import LOCAL, REMOTE, SENT, RECEIVED
from .lnutil import LOCAL, REMOTE, SENT, RECEIVED, HTLCOwner
from .transaction import Transaction
@@ -34,12 +34,22 @@ FUNDEE_ACKED = FeeUpdateProgress.FUNDEE_ACKED
FUNDER_SIGNED = FeeUpdateProgress.FUNDER_SIGNED
COMMITTED = FeeUpdateProgress.COMMITTED
class FeeUpdate:
from collections import namedtuple
def __init__(self, chan, feerate):
self.rate = feerate
self.proposed = chan.remote_state.ctn if not chan.constraints.is_initiator else chan.local_state.ctn
self.progress = {FUNDEE_SIGNED: None, FUNDEE_ACKED: None, FUNDER_SIGNED: None, COMMITTED: None}
class FeeUpdate:
def __init__(self, chan, **kwargs):
if 'rate' in kwargs:
self.rate = kwargs['rate']
else:
assert False
if 'proposed' not in kwargs:
self.proposed = chan.remote_state.ctn if not chan.constraints.is_initiator else chan.local_state.ctn
else:
self.proposed = kwargs['proposed']
if 'progress' not in kwargs:
self.progress = {FUNDEE_SIGNED: None, FUNDEE_ACKED: None, FUNDER_SIGNED: None, COMMITTED: None}
else:
self.progress = {FeeUpdateProgress[x.partition('.')[2]]: y for x,y in kwargs['progress'].items()}
self.chan = chan
@property
@@ -65,30 +75,30 @@ class FeeUpdate:
if subject == LOCAL and not self.chan.constraints.is_initiator:
return self.rate
class UpdateAddHtlc:
def __init__(self, amount_msat, payment_hash, cltv_expiry):
self.amount_msat = amount_msat
self.payment_hash = payment_hash
self.cltv_expiry = cltv_expiry
def to_save(self):
return {'rate': self.rate, 'proposed': self.proposed, 'progress': self.progress}
# the height the htlc was locked in at, or None
self.locked_in = {LOCAL: None, REMOTE: None}
self.settled = {LOCAL: None, REMOTE: None}
self.htlc_id = None
def as_tuple(self):
return (self.htlc_id, self.amount_msat, self.payment_hash, self.cltv_expiry, self.locked_in[REMOTE], self.locked_in[LOCAL], self.settled)
def __hash__(self):
return hash(self.as_tuple())
def __eq__(self, o):
return type(o) is UpdateAddHtlc and self.as_tuple() == o.as_tuple()
def __repr__(self):
return "UpdateAddHtlc" + str(self.as_tuple())
class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'settled', 'locked_in', 'htlc_id'])):
__slots__ = ()
def __new__(cls, *args, **kwargs):
if len(args) > 0:
args = list(args)
if type(args[1]) is str:
args[1] = bfh(args[1])
args[3] = {HTLCOwner(int(x)): y for x,y in args[3].items()}
args[4] = {HTLCOwner(int(x)): y for x,y in args[4].items()}
return super().__new__(cls, *args)
if type(kwargs['payment_hash']) is str:
kwargs['payment_hash'] = bfh(kwargs['payment_hash'])
if 'locked_in' not in kwargs:
kwargs['locked_in'] = {LOCAL: None, REMOTE: None}
else:
kwargs['locked_in'] = {HTLCOwner(int(x)): y for x,y in kwargs['locked_in']}
if 'settled' not in kwargs:
kwargs['settled'] = {LOCAL: None, REMOTE: None}
else:
kwargs['settled'] = {HTLCOwner(int(x)): y for x,y in kwargs['settled']}
return super().__new__(cls, **kwargs)
is_key = lambda k: k.endswith("_basepoint") or k.endswith("_key")
@@ -155,10 +165,22 @@ class HTLCStateMachine(PrintError):
self.remote_commitment_to_be_revoked = Transaction(state["remote_commitment_to_be_revoked"])
self.log = {LOCAL: [], REMOTE: []}
for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]:
if strname not in state: continue
for typ,y in state[strname]:
if typ == "UpdateAddHtlc":
self.log[subject].append(UpdateAddHtlc(*decodeAll(y)))
elif typ == "SettleHtlc":
self.log[subject].append(SettleHtlc(*decodeAll(y)))
else:
assert False
self.name = name
self.fee_mgr = []
if 'fee_updates' in state:
for y in state['fee_updates']:
self.fee_mgr.append(FeeUpdate(self, **y))
self.local_commitment = self.pending_local_commitment
self.remote_commitment = self.pending_remote_commitment
@@ -190,13 +212,12 @@ class HTLCStateMachine(PrintError):
AddHTLC adds an HTLC to the state machine's local update log. This method
should be called when preparing to send an outgoing HTLC.
"""
assert type(htlc) is UpdateAddHtlc
assert type(htlc) is dict
htlc = UpdateAddHtlc(**htlc, htlc_id=self.local_state.next_htlc_id)
self.log[LOCAL].append(htlc)
self.print_error("add_htlc")
htlc_id = self.local_state.next_htlc_id
self.local_state=self.local_state._replace(next_htlc_id=htlc_id + 1)
htlc.htlc_id = htlc_id
return htlc_id
self.local_state=self.local_state._replace(next_htlc_id=htlc.htlc_id + 1)
return htlc.htlc_id
def receive_htlc(self, htlc):
"""
@@ -204,13 +225,12 @@ class HTLCStateMachine(PrintError):
method should be called in response to receiving a new HTLC from the remote
party.
"""
self.print_error("receive_htlc")
assert type(htlc) is UpdateAddHtlc
assert type(htlc) is dict
htlc = UpdateAddHtlc(**htlc, htlc_id = self.remote_state.next_htlc_id)
self.log[REMOTE].append(htlc)
htlc_id = self.remote_state.next_htlc_id
self.remote_state=self.remote_state._replace(next_htlc_id=htlc_id + 1)
htlc.htlc_id = htlc_id
return htlc_id
self.print_error("receive_htlc")
self.remote_state=self.remote_state._replace(next_htlc_id=htlc.htlc_id + 1)
return htlc.htlc_id
def sign_next_commitment(self):
"""
@@ -431,6 +451,9 @@ class HTLCStateMachine(PrintError):
amount_msat = self.local_state.amount_msat + (received_this_batch - sent_this_batch)
)
self.balance(LOCAL)
self.balance(REMOTE)
for pending_fee in self.fee_mgr:
if pending_fee.is_proposed():
if self.constraints.is_initiator:
@@ -441,6 +464,26 @@ class HTLCStateMachine(PrintError):
self.remote_commitment_to_be_revoked = prev_remote_commitment
return received_this_batch, sent_this_batch
def balance(self, subject):
initial = self.local_config.initial_msat if subject == LOCAL else self.remote_config.initial_msat
for x in self.log[-subject]:
if type(x) is not SettleHtlc: continue
htlc = self.lookup_htlc(self.log[subject], x.htlc_id)
htlc_height = htlc.settled[subject]
if htlc_height is not None and htlc_height <= self.current_height[subject]:
initial -= htlc.amount_msat
for x in self.log[subject]:
if type(x) is not SettleHtlc: continue
htlc = self.lookup_htlc(self.log[-subject], x.htlc_id)
htlc_height = htlc.settled[-subject]
if htlc_height is not None and htlc_height <= self.current_height[-subject]:
initial += htlc.amount_msat
assert initial == (self.local_state.amount_msat if subject == LOCAL else self.remote_state.amount_msat)
return initial
@staticmethod
def htlcsum(htlcs):
amount_unsettled = 0
@@ -611,13 +654,13 @@ class HTLCStateMachine(PrintError):
def update_fee(self, feerate):
if not self.constraints.is_initiator:
raise Exception("only initiator can update_fee, this counterparty is not initiator")
pending_fee = FeeUpdate(self, feerate)
pending_fee = FeeUpdate(self, rate=feerate)
self.fee_mgr.append(pending_fee)
def receive_update_fee(self, feerate):
if self.constraints.is_initiator:
raise Exception("only the non-initiator can receive_update_fee, this counterparty is initiator")
pending_fee = FeeUpdate(self, feerate)
pending_fee = FeeUpdate(self, rate=feerate)
self.fee_mgr.append(pending_fee)
def to_save(self):
@@ -632,6 +675,9 @@ class HTLCStateMachine(PrintError):
"funding_outpoint": self.funding_outpoint,
"node_id": self.node_id,
"remote_commitment_to_be_revoked": str(self.remote_commitment_to_be_revoked),
"remote_log": [(type(x).__name__, x) for x in self.log[REMOTE]],
"local_log": [(type(x).__name__, x) for x in self.log[LOCAL]],
"fee_updates": [x.to_save() for x in self.fee_mgr],
}
def serialize(self):
@@ -643,7 +689,13 @@ class HTLCStateMachine(PrintError):
return binascii.hexlify(o).decode("ascii")
if isinstance(o, RevocationStore):
return o.serialize()
if isinstance(o, SettleHtlc):
return json.dumps(('SettleHtlc', namedtuples_to_dict(o)))
if isinstance(o, UpdateAddHtlc):
return json.dumps(('UpdateAddHtlc', namedtuples_to_dict(o)))
return super(MyJsonEncoder, self)
for fee_upd in serialized_channel['fee_updates']:
fee_upd['progress'] = {str(k): v for k,v in fee_upd['progress'].items()}
dumped = MyJsonEncoder().encode(serialized_channel)
roundtripped = json.loads(dumped)
reconstructed = HTLCStateMachine(roundtripped)