diff --git a/electrum/lnonion.py b/electrum/lnonion.py index a2f1a56a7..06b500f96 100644 --- a/electrum/lnonion.py +++ b/electrum/lnonion.py @@ -25,6 +25,7 @@ import io import hashlib +from functools import cached_property from typing import Sequence, List, Tuple, NamedTuple, TYPE_CHECKING, Dict, Any, Optional, Union, Mapping from enum import IntEnum from dataclasses import dataclass, field, replace @@ -157,6 +158,10 @@ class OnionPacket: version=b[0], ) + @cached_property + def onion_hash(self) -> bytes: + return sha256(self.to_bytes()) + def get_bolt04_onion_key(key_type: bytes, secret: bytes) -> bytes: if key_type not in (b'rho', b'mu', b'um', b'ammag', b'pad', b'blinded_node_id'): @@ -289,20 +294,19 @@ def calc_hops_data_for_payment( """ if len(route) > NUM_MAX_EDGES_IN_PAYMENT_PATH: raise PaymentFailure(f"too long route ({len(route)} edges)") - # payload that will be seen by the last hop: amt = amount_msat cltv_abs = final_cltv_abs + # payload that will be seen by the last hop: + # for multipart payments we need to tell the receiver about the total and + # partial amounts hop_payload = { "amt_to_forward": {"amt_to_forward": amt}, "outgoing_cltv_value": {"outgoing_cltv_value": cltv_abs}, - } - # for multipart payments we need to tell the receiver about the total and - # partial amounts - hop_payload["payment_data"] = { - "payment_secret": payment_secret, - "total_msat": total_msat, - "amount_msat": amt - } + "payment_data": { + "payment_secret": payment_secret, + "total_msat": total_msat, + "amount_msat": amt, + }} hops_data = [OnionHopsDataSingle(payload=hop_payload)] # payloads, backwards from last hop (but excluding the first edge): for edge_index in range(len(route) - 1, 0, -1): @@ -360,6 +364,36 @@ class ProcessedOnionPacket(NamedTuple): next_packet: OnionPacket trampoline_onion_packet: OnionPacket + @property + def amt_to_forward(self) -> Optional[int]: + k1 = k2 = 'amt_to_forward' + return self._get_from_payload(k1, k2, int) + + @property + def outgoing_cltv_value(self) -> Optional[int]: + k1 = k2 = 'outgoing_cltv_value' + return self._get_from_payload(k1, k2, int) + + @property + def next_chan_scid(self) -> Optional[ShortChannelID]: + k1 = k2 = 'short_channel_id' + return self._get_from_payload(k1, k2, ShortChannelID) + + @property + def total_msat(self) -> Optional[int]: + return self._get_from_payload('payment_data', 'total_msat', int) + + @property + def payment_secret(self) -> Optional[bytes]: + return self._get_from_payload('payment_data', 'payment_secret', bytes) + + def _get_from_payload(self, k1: str, k2: str, res_type: type): + try: + result = self.hop_data.payload[k1][k2] + return res_type(result) + except Exception: + return None + # TODO replay protection def process_onion_packet( diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index dc40eb040..faa32a9d3 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -2093,29 +2093,23 @@ class Peer(Logger, EventListener): Perform checks that are invariant (results do not depend on height, network conditions, etc). May raise OnionRoutingFailure """ - try: - amt_to_forward = processed_onion.hop_data.payload["amt_to_forward"]["amt_to_forward"] - except Exception: + assert processed_onion.are_we_final, processed_onion + if (amt_to_forward := processed_onion.amt_to_forward) is None: log_fail_reason(f"'amt_to_forward' missing from onion") raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00') - - exc_incorrect_or_unknown_pd = OnionRoutingFailure( - code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, - data=amt_to_forward.to_bytes(8, byteorder="big")) # height will be added later - try: - cltv_abs_from_onion = processed_onion.hop_data.payload["outgoing_cltv_value"]["outgoing_cltv_value"] - except Exception: + if (cltv_abs_from_onion := processed_onion.outgoing_cltv_value) is None: log_fail_reason(f"'outgoing_cltv_value' missing from onion") raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00') - if cltv_abs_from_onion > htlc.cltv_abs: log_fail_reason(f"cltv_abs_from_onion != htlc.cltv_abs") raise OnionRoutingFailure( code=OnionFailureCode.FINAL_INCORRECT_CLTV_EXPIRY, data=htlc.cltv_abs.to_bytes(4, byteorder="big")) - try: - total_msat = processed_onion.hop_data.payload["payment_data"]["total_msat"] # type: int - except Exception: + + exc_incorrect_or_unknown_pd = OnionRoutingFailure( + code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, + data=amt_to_forward.to_bytes(8, byteorder="big")) # height will be added later + if (total_msat := processed_onion.total_msat) is None: log_fail_reason(f"'total_msat' missing from onion") raise exc_incorrect_or_unknown_pd @@ -2132,9 +2126,7 @@ class Peer(Logger, EventListener): code=OnionFailureCode.FINAL_INCORRECT_HTLC_AMOUNT, data=htlc.amount_msat.to_bytes(8, byteorder="big")) - try: - payment_secret_from_onion = processed_onion.hop_data.payload["payment_data"]["payment_secret"] # type: bytes - except Exception: + if (payment_secret_from_onion := processed_onion.payment_secret) is None: log_fail_reason(f"'payment_secret' missing from onion") raise exc_incorrect_or_unknown_pd diff --git a/electrum/lnworker.py b/electrum/lnworker.py index b0850c0e3..6b9b0731d 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -3456,18 +3456,11 @@ class LNWallet(LNWorker): chain = self.network.blockchain() if chain.is_tip_stale(): raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') - try: - _next_chan_scid = processed_onion.hop_data.payload["short_channel_id"]["short_channel_id"] # type: bytes - next_chan_scid = ShortChannelID(_next_chan_scid) - except Exception: + if (next_chan_scid := processed_onion.next_chan_scid) is None: raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00') - try: - next_amount_msat_htlc = processed_onion.hop_data.payload["amt_to_forward"]["amt_to_forward"] - except Exception: + if (next_amount_msat_htlc := processed_onion.amt_to_forward) is None: raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00') - try: - next_cltv_abs = processed_onion.hop_data.payload["outgoing_cltv_value"]["outgoing_cltv_value"] - except Exception: + if (next_cltv_abs := processed_onion.outgoing_cltv_value) is None: raise OnionRoutingFailure(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00') next_chan = self.get_channel_by_short_id(next_chan_scid)