1
0

Merge pull request #9848 from SomberNight/202505_refactor_lnchannel_ctx_updates

lnchannel: refactor can_send_ctx_updates
This commit is contained in:
ghost43
2025-05-21 15:49:34 +00:00
committed by GitHub
2 changed files with 98 additions and 50 deletions

View File

@@ -1054,18 +1054,21 @@ class Channel(AbstractChannel):
def set_can_send_ctx_updates(self, b: bool) -> None:
self._can_send_ctx_updates = b
def can_send_ctx_updates(self) -> bool:
"""Whether we can send update_fee, update_*_htlc changes to the remote."""
def can_update_ctx(self, *, proposer: HTLCOwner) -> bool:
"""Whether proposer is allowed to send commitment_signed, revoke_and_ack,
and update_* messages.
"""
if self.get_state() not in (ChannelState.OPEN, ChannelState.SHUTDOWN):
return False
if self.peer_state != PeerState.GOOD:
return False
if not self._can_send_ctx_updates:
return False
if proposer == LOCAL:
if not self._can_send_ctx_updates:
return False
return True
def can_send_update_add_htlc(self) -> bool:
return self.can_send_ctx_updates() and self.is_open()
return self.can_update_ctx(proposer=LOCAL) and self.is_open()
def is_frozen_for_sending(self) -> bool:
if self.lnworker and self.lnworker.uses_trampoline() and not self.lnworker.is_trampoline_peer(self.node_id):
@@ -1096,10 +1099,10 @@ class Channel(AbstractChannel):
ctn = self.get_next_ctn(htlc_receiver)
chan_config = self.config[htlc_receiver]
if self.get_state() != ChannelState.OPEN:
raise PaymentFailure('Channel not open', self.get_state())
raise PaymentFailure(f"Channel not open. {self.get_state()!r}")
if not self.can_update_ctx(proposer=htlc_proposer):
raise PaymentFailure(f"cannot update channel. {self.get_state()!r} {self.peer_state!r}")
if htlc_proposer == LOCAL:
if not self.can_send_ctx_updates():
raise PaymentFailure('Channel cannot send ctx updates')
if not self.can_send_update_add_htlc():
raise PaymentFailure('Channel cannot add htlc')
@@ -1239,6 +1242,7 @@ class Channel(AbstractChannel):
# TODO: when more channel types are supported, this method should depend on channel type
next_remote_ctn = self.get_next_ctn(REMOTE)
self.logger.info(f"sign_next_commitment. ctn={next_remote_ctn}")
assert not self.is_closed(), self.get_state()
pending_remote_commitment = self.get_next_commitment(REMOTE)
sig_64 = sign_and_get_sig_string(pending_remote_commitment, self.config[LOCAL], self.config[REMOTE])
@@ -1286,6 +1290,7 @@ class Channel(AbstractChannel):
# TODO: when more channel types are supported, this method should depend on channel type
next_local_ctn = self.get_next_ctn(LOCAL)
self.logger.info(f"receive_new_commitment. ctn={next_local_ctn}, len(htlc_sigs)={len(htlc_sigs)}")
assert not self.is_closed(), self.get_state()
assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes
@@ -1364,6 +1369,7 @@ class Channel(AbstractChannel):
def revoke_current_commitment(self):
self.logger.info("revoke_current_commitment")
assert not self.is_closed(), self.get_state()
new_ctn = self.get_latest_ctn(LOCAL)
new_ctx = self.get_latest_commitment(LOCAL)
if not self.signature_fits(new_ctx):
@@ -1377,6 +1383,7 @@ class Channel(AbstractChannel):
def receive_revocation(self, revocation: RevokeAndAck):
self.logger.info("receive_revocation")
assert not self.is_closed(), self.get_state()
new_ctn = self.get_latest_ctn(REMOTE)
cur_point = self.config[REMOTE].current_per_commitment_point
derived_point = ecc.ECPrivkey(revocation.per_commitment_secret).get_public_key_bytes(compressed=True)
@@ -1689,7 +1696,7 @@ class Channel(AbstractChannel):
Action must be initiated by LOCAL.
"""
self.logger.info("settle_htlc")
assert self.can_send_ctx_updates(), f"cannot update channel. {self.get_state()!r} {self.peer_state!r}"
assert self.can_update_ctx(proposer=LOCAL), f"cannot update channel. {self.get_state()!r} {self.peer_state!r}"
htlc = self.hm.get_htlc_by_id(REMOTE, htlc_id)
if htlc.payment_hash != sha256(preimage):
raise Exception("incorrect preimage for HTLC")
@@ -1706,6 +1713,7 @@ class Channel(AbstractChannel):
Action must be initiated by REMOTE.
"""
self.logger.info("receive_htlc_settle")
assert self.can_update_ctx(proposer=REMOTE), f"cannot update channel. {self.get_state()!r} {self.peer_state!r}"
htlc = self.hm.get_htlc_by_id(LOCAL, htlc_id)
if htlc.payment_hash != sha256(preimage):
raise RemoteMisbehaving("received incorrect preimage for HTLC")
@@ -1718,7 +1726,7 @@ class Channel(AbstractChannel):
Action must be initiated by LOCAL.
"""
self.logger.info("fail_htlc")
assert self.can_send_ctx_updates(), f"cannot update channel. {self.get_state()!r} {self.peer_state!r}"
assert self.can_update_ctx(proposer=LOCAL), f"cannot update channel. {self.get_state()!r} {self.peer_state!r}"
with self.db_lock:
self.hm.send_fail(htlc_id)
@@ -1729,6 +1737,7 @@ class Channel(AbstractChannel):
Action must be initiated by REMOTE.
"""
self.logger.info("receive_fail_htlc")
assert self.can_update_ctx(proposer=REMOTE), f"cannot update channel. {self.get_state()!r} {self.peer_state!r}"
with self.db_lock:
self.hm.recv_fail(htlc_id)
self._receive_fail_reasons[htlc_id] = (error_bytes, reason)
@@ -1762,9 +1771,9 @@ class Channel(AbstractChannel):
if remainder < 0:
raise Exception(f"Cannot update_fee. {sender} tried to update fee but they cannot afford it. "
f"Their balance would go below reserve: {remainder} msat missing.")
assert self.can_update_ctx(proposer=LOCAL if from_us else REMOTE), f"cannot update channel. {self.get_state()!r} {self.peer_state!r}. {from_us=}"
with self.db_lock:
if from_us:
assert self.can_send_ctx_updates(), f"cannot update channel. {self.get_state()!r} {self.peer_state!r}"
self.hm.send_update_fee(feerate)
else:
self.hm.recv_update_fee(feerate)

View File

@@ -1465,6 +1465,11 @@ class Peer(Logger, EventListener):
f'channel_reestablish ({chan.get_id_for_log()}): received channel_reestablish with '
f'(their_next_local_ctn={their_next_local_ctn}, '
f'their_oldest_unrevoked_remote_ctn={their_oldest_unrevoked_remote_ctn})')
if chan.get_state() >= ChannelState.CLOSED:
self.logger.warning(
f"on_channel_reestablish. dropping message. illegal action. "
f"chan={chan.get_id_for_log()}. {chan.get_state()=!r}. {chan.peer_state=!r}")
return
# sanity checks of received values
if their_next_local_ctn < 0:
raise RemoteMisbehaving(f"channel reestablish: their_next_local_ctn < 0")
@@ -1720,6 +1725,11 @@ class Peer(Logger, EventListener):
self.logger.info(f"on_channel_ready. channel: {chan.channel_id.hex()}")
if chan.peer_state != PeerState.GOOD: # should never happen
raise Exception(f"received channel_ready in unexpected {chan.peer_state=!r}")
if chan.is_closed():
self.logger.warning(
f"on_channel_ready. dropping message. illegal action. "
f"chan={chan.get_id_for_log()}. {chan.get_state()=!r}. {chan.peer_state=!r}")
return
# save remote alias for use in invoices
scid_alias = payload.get('channel_ready_tlvs', {}).get('short_channel_id', {}).get('alias')
if scid_alias:
@@ -1844,14 +1854,17 @@ class Peer(Logger, EventListener):
htlc_id = payload["id"]
reason = payload["reason"]
self.logger.info(f"on_update_fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
if chan.peer_state != PeerState.GOOD: # should never happen
raise Exception(f"received update_fail_htlc in unexpected {chan.peer_state=!r}")
if not chan.can_update_ctx(proposer=REMOTE):
self.logger.warning(
f"on_update_fail_htlc. dropping message. illegal action. "
f"chan={chan.get_id_for_log()}. {htlc_id=}. {chan.get_state()=!r}. {chan.peer_state=!r}")
return
chan.receive_fail_htlc(htlc_id, error_bytes=reason) # TODO handle exc and maybe fail channel (e.g. bad htlc_id)
self.maybe_send_commitment(chan)
def maybe_send_commitment(self, chan: Channel) -> bool:
assert util.get_running_loop() == util.get_asyncio_loop(), f"this must be run on the asyncio thread!"
if chan.is_closed():
if not chan.can_update_ctx(proposer=LOCAL):
return False
# REMOTE should revoke first before we can sign a new ctx
if chan.hm.is_revack_pending(REMOTE):
@@ -1925,6 +1938,7 @@ class Peer(Logger, EventListener):
onion: OnionPacket,
session_key: Optional[bytes] = None,
) -> UpdateAddHtlc:
assert chan.can_send_update_add_htlc(), f"cannot send updates: {chan.short_channel_id}"
htlc = UpdateAddHtlc(amount_msat=amount_msat, payment_hash=payment_hash, cltv_abs=cltv_abs, timestamp=int(time.time()))
htlc = chan.add_htlc(htlc)
if session_key:
@@ -1975,8 +1989,8 @@ class Peer(Logger, EventListener):
)
return htlc
def send_revoke_and_ack(self, chan: Channel):
if chan.is_closed():
def send_revoke_and_ack(self, chan: Channel) -> None:
if not chan.can_update_ctx(proposer=LOCAL):
return
self.logger.info(f'send_revoke_and_ack. chan {chan.short_channel_id}. ctn: {chan.get_oldest_unrevoked_ctn(LOCAL)}')
rev = chan.revoke_current_commitment()
@@ -1987,12 +2001,13 @@ class Peer(Logger, EventListener):
next_per_commitment_point=rev.next_per_commitment_point)
self.maybe_send_commitment(chan)
def on_commitment_signed(self, chan: Channel, payload):
if chan.peer_state == PeerState.BAD:
return
def on_commitment_signed(self, chan: Channel, payload) -> None:
self.logger.info(f'on_commitment_signed. chan {chan.short_channel_id}. ctn: {chan.get_next_ctn(LOCAL)}.')
if chan.peer_state != PeerState.GOOD: # should never happen
raise Exception(f"received commitment_signed in unexpected {chan.peer_state=!r}")
if not chan.can_update_ctx(proposer=REMOTE):
self.logger.warning(
f"on_commitment_signed. dropping message. illegal action. "
f"chan={chan.get_id_for_log()}. {chan.get_state()=!r}. {chan.peer_state=!r}")
return
# make sure there were changes to the ctx, otherwise the remote peer is misbehaving
if not chan.has_pending_changes(LOCAL):
# TODO if feerate changed A->B->A; so there were updates but the value is identical,
@@ -2014,8 +2029,11 @@ class Peer(Logger, EventListener):
payment_hash = sha256(preimage)
htlc_id = payload["id"]
self.logger.info(f"on_update_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
if chan.peer_state != PeerState.GOOD: # should never happen
raise Exception(f"received update_fulfill_htlc in unexpected {chan.peer_state=!r}")
if not chan.can_update_ctx(proposer=REMOTE):
self.logger.warning(
f"on_update_fulfill_htlc. dropping message. illegal action. "
f"chan={chan.get_id_for_log()}. {htlc_id=}. {chan.get_state()=!r}. {chan.peer_state=!r}")
return
chan.receive_htlc_settle(preimage, htlc_id) # TODO handle exc and maybe fail channel (e.g. bad htlc_id)
self.lnworker.save_preimage(payment_hash, preimage)
self.maybe_send_commitment(chan)
@@ -2025,8 +2043,11 @@ class Peer(Logger, EventListener):
failure_code = payload["failure_code"]
self.logger.info(f"on_update_fail_malformed_htlc. chan {chan.get_id_for_log()}. "
f"htlc_id {htlc_id}. failure_code={failure_code}")
if chan.peer_state != PeerState.GOOD: # should never happen
raise Exception(f"received update_fail_malformed_htlc in unexpected {chan.peer_state=!r}")
if not chan.can_update_ctx(proposer=REMOTE):
self.logger.warning(
f"on_update_fail_malformed_htlc. dropping message. illegal action. "
f"chan={chan.get_id_for_log()}. {htlc_id=}. {chan.get_state()=!r}. {chan.peer_state=!r}")
return
if failure_code & OnionFailureCodeMetaFlag.BADONION == 0:
self.schedule_force_closing(chan.channel_id)
raise RemoteMisbehaving(f"received update_fail_malformed_htlc with unexpected failure code: {failure_code}")
@@ -2049,8 +2070,11 @@ class Peer(Logger, EventListener):
self.logger.info(f"on_update_add_htlc. chan {chan.short_channel_id}. htlc={str(htlc)}")
if chan.get_state() != ChannelState.OPEN:
raise RemoteMisbehaving(f"received update_add_htlc while chan.get_state() != OPEN. state was {chan.get_state()!r}")
if chan.peer_state != PeerState.GOOD: # should never happen
raise Exception(f"received update_add_htlc in unexpected {chan.peer_state=!r}")
if not chan.can_update_ctx(proposer=REMOTE):
self.logger.warning(
f"on_update_add_htlc. dropping message. illegal action. "
f"chan={chan.get_id_for_log()}. {htlc_id=}. {chan.get_state()=!r}. {chan.peer_state=!r}")
return
if cltv_abs > bitcoin.NLOCKTIME_BLOCKHEIGHT_MAX:
self.schedule_force_closing(chan.channel_id)
raise RemoteMisbehaving(f"received update_add_htlc with {cltv_abs=} > BLOCKHEIGHT_MAX")
@@ -2540,7 +2564,7 @@ class Peer(Logger, EventListener):
def fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes):
self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
assert chan.can_update_ctx(proposer=LOCAL), f"cannot send updates: {chan.short_channel_id}"
assert chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id)
self.received_htlcs_pending_removal.add((chan, htlc_id))
chan.settle_htlc(preimage, htlc_id)
@@ -2552,7 +2576,7 @@ class Peer(Logger, EventListener):
def fail_htlc(self, *, chan: Channel, htlc_id: int, error_bytes: bytes):
self.logger.info(f"fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.")
assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
assert chan.can_update_ctx(proposer=LOCAL), f"cannot send updates: {chan.short_channel_id}"
self.received_htlcs_pending_removal.add((chan, htlc_id))
chan.fail_htlc(htlc_id)
self.send_message(
@@ -2564,7 +2588,7 @@ class Peer(Logger, EventListener):
def fail_malformed_htlc(self, *, chan: Channel, htlc_id: int, reason: OnionRoutingFailure):
self.logger.info(f"fail_malformed_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.")
assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
assert chan.can_update_ctx(proposer=LOCAL), f"cannot send updates: {chan.short_channel_id}"
if not (reason.code & OnionFailureCodeMetaFlag.BADONION and len(reason.data) == 32):
raise Exception(f"unexpected reason when sending 'update_fail_malformed_htlc': {reason!r}")
self.received_htlcs_pending_removal.add((chan, htlc_id))
@@ -2576,12 +2600,13 @@ class Peer(Logger, EventListener):
sha256_of_onion=reason.data,
failure_code=reason.code)
def on_revoke_and_ack(self, chan: Channel, payload):
if chan.peer_state == PeerState.BAD:
return
def on_revoke_and_ack(self, chan: Channel, payload) -> None:
self.logger.info(f'on_revoke_and_ack. chan {chan.short_channel_id}. ctn: {chan.get_oldest_unrevoked_ctn(REMOTE)}')
if chan.peer_state != PeerState.GOOD: # should never happen
raise Exception(f"received revoke_and_ack in unexpected {chan.peer_state=!r}")
if not chan.can_update_ctx(proposer=REMOTE):
self.logger.warning(
f"on_revoke_and_ack. dropping message. illegal action. "
f"chan={chan.get_id_for_log()}. {chan.get_state()=!r}. {chan.peer_state=!r}")
return
rev = RevokeAndAck(payload["per_commitment_secret"], payload["next_per_commitment_point"])
chan.receive_revocation(rev)
self.lnworker.save_channel(chan)
@@ -2597,8 +2622,11 @@ class Peer(Logger, EventListener):
await self.taskgroup.spawn(async_wrapper)
def on_update_fee(self, chan: Channel, payload):
if chan.peer_state != PeerState.GOOD: # should never happen
raise Exception(f"received update_fee in unexpected {chan.peer_state=!r}")
if not chan.can_update_ctx(proposer=REMOTE):
self.logger.warning(
f"on_update_fee. dropping message. illegal action. "
f"chan={chan.get_id_for_log()}. {chan.get_state()=!r}. {chan.peer_state=!r}")
return
feerate = payload["feerate_per_kw"]
chan.update_fee(feerate, False)
@@ -2606,7 +2634,7 @@ class Peer(Logger, EventListener):
"""
called when our fee estimates change
"""
if not chan.can_send_ctx_updates():
if not chan.can_update_ctx(proposer=LOCAL):
return
if chan.get_state() != ChannelState.OPEN:
return
@@ -2690,11 +2718,13 @@ class Peer(Logger, EventListener):
@non_blocking_msg_handler
async def on_shutdown(self, chan: Channel, payload):
# TODO: A receiving node: if it hasn't received a funding_signed (if it is a
# funder) or a funding_created (if it is a fundee):
# SHOULD send an error and fail the channel.
if chan.peer_state != PeerState.GOOD: # should never happen
raise Exception(f"received shutdown in unexpected {chan.peer_state=!r}")
if not self.can_send_shutdown(chan, proposer=REMOTE):
self.logger.warning(
f"on_shutdown. illegal action. "
f"chan={chan.get_id_for_log()}. {chan.get_state()=!r}. {chan.peer_state=!r}")
self.send_error(chan.channel_id, message="cannot process 'shutdown' in current channel state.")
their_scriptpubkey = payload['scriptpubkey']
their_upfront_scriptpubkey = chan.config[REMOTE].upfront_shutdown_script
# BOLT-02 check if they use the upfront shutdown script they advertised
@@ -2720,23 +2750,32 @@ class Peer(Logger, EventListener):
if chan_id in self.shutdown_received:
self.shutdown_received[chan_id].set_result(payload)
else:
chan = self.channels[chan_id]
await self.send_shutdown(chan)
txid = await self._shutdown(chan, payload, is_local=False)
self.logger.info(f'({chan.get_id_for_log()}) Channel closed by remote peer {txid}')
def can_send_shutdown(self, chan: Channel):
def can_send_shutdown(self, chan: Channel, *, proposer: HTLCOwner) -> bool:
if chan.get_state() >= ChannelState.CLOSED:
return False
if chan.get_state() >= ChannelState.OPENING:
return True
if chan.constraints.is_initiator and chan.channel_id in self.funding_created_sent:
return True
if not chan.constraints.is_initiator and chan.channel_id in self.funding_signed_sent:
return True
if proposer == LOCAL:
if chan.constraints.is_initiator and chan.channel_id in self.funding_created_sent:
return True
if not chan.constraints.is_initiator and chan.channel_id in self.funding_signed_sent:
return True
else: # proposer == REMOTE
# (from BOLT-02)
# A receiving node:
# - if it hasn't received a funding_signed (if it is a funder) or a funding_created (if it is a fundee):
# - SHOULD send an error and fail the channel.
# ^ that check is equivalent to `chan.get_state() < ChannelState.OPENING`, which is already checked.
pass
return False
async def send_shutdown(self, chan: Channel):
if not self.can_send_shutdown(chan):
raise Exception('cannot send shutdown')
if not self.can_send_shutdown(chan, proposer=LOCAL):
raise Exception(f"cannot send shutdown. chan={chan.get_id_for_log()}. {chan.get_state()=!r}")
if chan.config[LOCAL].upfront_shutdown_script:
scriptpubkey = chan.config[LOCAL].upfront_shutdown_script
else:
@@ -2984,7 +3023,7 @@ class Peer(Logger, EventListener):
self._htlc_switch_iterstart_event.clear()
self._maybe_cleanup_received_htlcs_pending_removal()
for chan_id, chan in self.channels.items():
if not chan.can_send_ctx_updates():
if not chan.can_update_ctx(proposer=LOCAL):
continue
self.maybe_send_commitment(chan)
done = set()