1
0

simplify maybe_fulfill_htlc

- move some checks in two helper methods:
    - invariant checks are performed in check_accepted_htlc
    - mpp checks are performed in check_mpp_is waiting
 - in order to avoid passing local_height to check_accepted_htlc,
   the height in the error message is added by create_onion_error.
This commit is contained in:
ThomasV
2024-11-19 11:00:50 +01:00
parent 5704276cbe
commit 81aed0f6c9
2 changed files with 78 additions and 42 deletions

View File

@@ -391,12 +391,16 @@ class OnionRoutingFailure(Exception):
def construct_onion_error( def construct_onion_error(
reason: OnionRoutingFailure, error: OnionRoutingFailure,
their_public_key: bytes, their_public_key: bytes,
our_onion_private_key: bytes, our_onion_private_key: bytes,
local_height: int
) -> bytes: ) -> bytes:
# add local height
if error.code == OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS:
error.data += local_height.to_bytes(4, byteorder="big")
# create payload # create payload
failure_msg = reason.to_bytes() failure_msg = error.to_bytes()
failure_len = len(failure_msg) failure_len = len(failure_msg)
pad_len = 256 - failure_len pad_len = 256 - failure_len
assert pad_len >= 0 assert pad_len >= 0

View File

@@ -2022,51 +2022,26 @@ class Peer(Logger, EventListener):
# also make us fail arbitrary HTLCs. # also make us fail arbitrary HTLCs.
return bool(is_our_payreq and self.lnworker.get_preimage(payment_hash)) return bool(is_our_payreq and self.lnworker.get_preimage(payment_hash))
def maybe_fulfill_htlc( def check_accepted_htlc(
self, *, self, *,
chan: Channel, chan: Channel,
htlc: UpdateAddHtlc, htlc: UpdateAddHtlc,
processed_onion: ProcessedOnionPacket, processed_onion: ProcessedOnionPacket,
onion_packet_bytes: bytes, log_fail_reason: Callable,
already_forwarded: bool = False, ):
) -> Tuple[Optional[bytes], Optional[Tuple[str, Callable[[], Awaitable[Optional[str]]]]]]:
""" """
Decide what to do with an HTLC: return preimage if it can be fulfilled, forwarding callback if it can be forwarded. Perform checks that are invariant (results do not depend on height, network conditions, etc).
Return (preimage, (payment_key, callback)) with at most a single element not None. May raise OnionRoutingFailure
""" """
if not processed_onion.are_we_final:
if not self.lnworker.enable_htlc_forwarding:
return None, None
# use the htlc key if we are forwarding
payment_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc.htlc_id)
callback = lambda: self.maybe_forward_htlc(
incoming_chan=chan,
htlc=htlc,
processed_onion=processed_onion)
return None, (payment_key, callback)
def log_fail_reason(reason: str):
self.logger.info(
f"maybe_fulfill_htlc. will FAIL HTLC: chan {chan.short_channel_id}. "
f"{reason}. htlc={str(htlc)}. onion_payload={processed_onion.hop_data.payload}")
try: try:
amt_to_forward = processed_onion.hop_data.payload["amt_to_forward"]["amt_to_forward"] amt_to_forward = processed_onion.hop_data.payload["amt_to_forward"]["amt_to_forward"]
except Exception: except Exception:
log_fail_reason(f"'amt_to_forward' missing from onion") log_fail_reason(f"'amt_to_forward' missing from onion")
raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00') raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00')
# Check that our blockchain tip is sufficiently recent so that we have an approx idea of the height.
# We should not release the preimage for an HTLC that its sender could already time out as
# then they might try to force-close and it becomes a race.
chain = self.network.blockchain()
if chain.is_tip_stale() and not already_forwarded:
log_fail_reason(f"our chain tip is stale")
raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'')
local_height = chain.height()
exc_incorrect_or_unknown_pd = OnionRoutingFailure( exc_incorrect_or_unknown_pd = OnionRoutingFailure(
code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS,
data=amt_to_forward.to_bytes(8, byteorder="big") + local_height.to_bytes(4, byteorder="big")) data=amt_to_forward.to_bytes(8, byteorder="big")) # height will be added later
try: try:
cltv_abs_from_onion = processed_onion.hop_data.payload["outgoing_cltv_value"]["outgoing_cltv_value"] cltv_abs_from_onion = processed_onion.hop_data.payload["outgoing_cltv_value"]["outgoing_cltv_value"]
except Exception: except Exception:
@@ -2103,19 +2078,18 @@ class Peer(Logger, EventListener):
log_fail_reason(f"'payment_secret' missing from onion") log_fail_reason(f"'payment_secret' missing from onion")
raise exc_incorrect_or_unknown_pd raise exc_incorrect_or_unknown_pd
# payment key for final onions return payment_secret_from_onion, total_msat, channel_opening_fee, exc_incorrect_or_unknown_pd
payment_hash = htlc.payment_hash
payment_key = (payment_hash + payment_secret_from_onion).hex()
def check_mpp_is_waiting(self, *, payment_secret, short_channel_id, htlc, expected_msat, exc_incorrect_or_unknown_pd, log_fail_reason) -> bool:
from .lnworker import RecvMPPResolution from .lnworker import RecvMPPResolution
mpp_resolution = self.lnworker.check_mpp_status( mpp_resolution = self.lnworker.check_mpp_status(
payment_secret=payment_secret_from_onion, payment_secret=payment_secret,
short_channel_id=chan.get_scid_or_local_alias(), short_channel_id=short_channel_id,
htlc=htlc, htlc=htlc,
expected_msat=total_msat, expected_msat=expected_msat,
) )
if mpp_resolution == RecvMPPResolution.WAITING: if mpp_resolution == RecvMPPResolution.WAITING:
return None, None return True
elif mpp_resolution == RecvMPPResolution.EXPIRED: elif mpp_resolution == RecvMPPResolution.EXPIRED:
log_fail_reason(f"MPP_TIMEOUT") log_fail_reason(f"MPP_TIMEOUT")
raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'') raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'')
@@ -2123,10 +2097,68 @@ class Peer(Logger, EventListener):
log_fail_reason(f"mpp_resolution is FAILED") log_fail_reason(f"mpp_resolution is FAILED")
raise exc_incorrect_or_unknown_pd raise exc_incorrect_or_unknown_pd
elif mpp_resolution == RecvMPPResolution.ACCEPTED: elif mpp_resolution == RecvMPPResolution.ACCEPTED:
pass # continue return False
else: else:
raise Exception(f"unexpected {mpp_resolution=}") raise Exception(f"unexpected {mpp_resolution=}")
def maybe_fulfill_htlc(
self, *,
chan: Channel,
htlc: UpdateAddHtlc,
processed_onion: ProcessedOnionPacket,
onion_packet_bytes: bytes,
already_forwarded: bool = False,
) -> Tuple[Optional[bytes], Optional[Tuple[str, Callable[[], Awaitable[Optional[str]]]]]]:
"""
Decide what to do with an HTLC: return preimage if it can be fulfilled, forwarding callback if it can be forwarded.
Return (preimage, (payment_key, callback)) with at most a single element not None.
"""
if not processed_onion.are_we_final:
if not self.lnworker.enable_htlc_forwarding:
return None, None
# use the htlc key if we are forwarding
payment_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc.htlc_id)
callback = lambda: self.maybe_forward_htlc(
incoming_chan=chan,
htlc=htlc,
processed_onion=processed_onion)
return None, (payment_key, callback)
def log_fail_reason(reason: str):
self.logger.info(
f"maybe_fulfill_htlc. will FAIL HTLC: chan {chan.short_channel_id}. "
f"{reason}. htlc={str(htlc)}. onion_payload={processed_onion.hop_data.payload}")
chain = self.network.blockchain()
# Check that our blockchain tip is sufficiently recent so that we have an approx idea of the height.
# We should not release the preimage for an HTLC that its sender could already time out as
# then they might try to force-close and it becomes a race.
if chain.is_tip_stale() and not already_forwarded:
log_fail_reason(f"our chain tip is stale")
raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'')
local_height = chain.height()
# parse parameters and perform checks that are invariant
payment_secret_from_onion, total_msat, channel_opening_fee, exc_incorrect_or_unknown_pd = self.check_accepted_htlc(
chan=chan,
htlc=htlc,
processed_onion=processed_onion,
log_fail_reason=log_fail_reason)
# payment key for final onions
payment_hash = htlc.payment_hash
payment_key = (payment_hash + payment_secret_from_onion).hex()
if self.check_mpp_is_waiting(
payment_secret=payment_secret_from_onion,
short_channel_id=chan.get_scid_or_local_alias(),
htlc=htlc,
expected_msat=total_msat,
exc_incorrect_or_unknown_pd=exc_incorrect_or_unknown_pd,
log_fail_reason=log_fail_reason,
):
return None, None
# TODO check against actual min_final_cltv_expiry_delta from invoice (and give 2-3 blocks of leeway?) # TODO check against actual min_final_cltv_expiry_delta from invoice (and give 2-3 blocks of leeway?)
if local_height + MIN_FINAL_CLTV_DELTA_ACCEPTED > htlc.cltv_abs and not already_forwarded: if local_height + MIN_FINAL_CLTV_DELTA_ACCEPTED > htlc.cltv_abs and not already_forwarded:
log_fail_reason(f"htlc.cltv_abs is unreasonably close") log_fail_reason(f"htlc.cltv_abs is unreasonably close")
@@ -2665,7 +2697,7 @@ class Peer(Logger, EventListener):
assert forwarding_key is None assert forwarding_key is None
unfulfilled[htlc_id] = onion_packet_hex, _forwarding_key unfulfilled[htlc_id] = onion_packet_hex, _forwarding_key
except OnionRoutingFailure as e: except OnionRoutingFailure as e:
error_bytes = construct_onion_error(e, onion_packet.public_key, our_onion_private_key=self.privkey) error_bytes = construct_onion_error(e, onion_packet.public_key, self.privkey, self.network.get_local_height())
if error_bytes: if error_bytes:
error_bytes = obfuscate_onion_error(error_bytes, onion_packet.public_key, our_onion_private_key=self.privkey) error_bytes = obfuscate_onion_error(error_bytes, onion_packet.public_key, our_onion_private_key=self.privkey)