1
0

lnonion: immutable OnionPacket and OnionHopsDataSingle

Make OnionHopsDataSingle and OnionPacket immutable for safer caching and
handling.

# Conflicts:
#	electrum/onion_message.py
This commit is contained in:
f321x
2025-10-13 14:24:42 +02:00
parent 1ad6607405
commit 936e7fd1c2
8 changed files with 148 additions and 92 deletions

View File

@@ -27,6 +27,8 @@ import io
import hashlib
from typing import Sequence, List, Tuple, NamedTuple, TYPE_CHECKING, Dict, Any, Optional, Union
from enum import IntEnum
from dataclasses import dataclass, field, replace
from types import MappingProxyType
import electrum_ecc as ecc
@@ -53,18 +55,22 @@ class InvalidOnionPubkey(Exception): pass
class InvalidPayloadSize(Exception): pass
class OnionHopsDataSingle: # called HopData in lnd
@dataclass(frozen=True, kw_only=True)
class OnionHopsDataSingle:
payload: MappingProxyType = field(default_factory=lambda: MappingProxyType({}))
hmac: Optional[bytes] = None
tlv_stream_name: str = 'payload'
blind_fields: MappingProxyType = field(default_factory=lambda: MappingProxyType({}))
_raw_bytes_payload: Optional[bytes] = None
def __init__(self, *, payload: dict = None, tlv_stream_name: str = 'payload', blind_fields: dict = None):
if payload is None:
payload = {}
self.payload = payload
self.hmac = None
self.tlv_stream_name = tlv_stream_name
if blind_fields is None:
blind_fields = {}
self.blind_fields = blind_fields
self._raw_bytes_payload = None # used in unit tests
def __post_init__(self):
# make all fields immutable recursively
object.__setattr__(self, 'payload', util.make_object_immutable(self.payload))
object.__setattr__(self, 'blind_fields', util.make_object_immutable(self.blind_fields))
assert isinstance(self.payload, MappingProxyType)
assert isinstance(self.blind_fields, MappingProxyType)
assert isinstance(self.tlv_stream_name, str)
assert (isinstance(self.hmac, bytes) and len(self.hmac) == PER_HOP_HMAC_SIZE) or self.hmac is None
def to_bytes(self) -> bytes:
hmac_ = self.hmac if self.hmac is not None else bytes(PER_HOP_HMAC_SIZE)
@@ -101,32 +107,35 @@ class OnionHopsDataSingle: # called HopData in lnd
hop_payload = fd.read(hop_payload_length)
if hop_payload_length != len(hop_payload):
raise Exception(f"unexpected EOF")
ret = OnionHopsDataSingle(tlv_stream_name=tlv_stream_name)
ret.payload = OnionWireSerializer.read_tlv_stream(fd=io.BytesIO(hop_payload),
tlv_stream_name=tlv_stream_name)
ret.hmac = fd.read(PER_HOP_HMAC_SIZE)
assert len(ret.hmac) == PER_HOP_HMAC_SIZE
payload = OnionWireSerializer.read_tlv_stream(fd=io.BytesIO(hop_payload),
tlv_stream_name=tlv_stream_name)
ret = OnionHopsDataSingle(
tlv_stream_name=tlv_stream_name,
payload=MappingProxyType(payload),
hmac=fd.read(PER_HOP_HMAC_SIZE)
)
return ret
def __repr__(self):
return f"<OnionHopsDataSingle. payload={self.payload}. hmac={self.hmac}>"
return f"<OnionHopsDataSingle. {self.payload=}. {self.hmac=}>"
@dataclass(frozen=True, kw_only=True)
class OnionPacket:
public_key: bytes
hops_data: bytes # also called RoutingInfo in bolt-04
hmac: bytes
version: int = 0
# for debugging our own onions:
_debug_hops_data: Optional[Sequence[OnionHopsDataSingle]] = None
_debug_route: Optional['LNPaymentRoute'] = None
def __init__(self, *, public_key: bytes, hops_data: bytes, hmac: bytes, version: int = 0):
assert len(public_key) == 33
assert len(hops_data) in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]
assert len(hmac) == PER_HOP_HMAC_SIZE
self.version = version
self.public_key = public_key
self.hops_data = hops_data # also called RoutingInfo in bolt-04
self.hmac = hmac
if not ecc.ECPubkey.is_pubkey_bytes(public_key):
def __post_init__(self):
assert len(self.public_key) == 33
assert len(self.hops_data) in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]
assert len(self.hmac) == PER_HOP_HMAC_SIZE
if not ecc.ECPubkey.is_pubkey_bytes(self.public_key):
raise InvalidOnionPubkey()
# for debugging our own onions:
self._debug_hops_data = None # type: Optional[Sequence[OnionHopsDataSingle]]
self._debug_route = None # type: Optional[LNPaymentRoute]
def to_bytes(self) -> bytes:
ret = bytes([self.version])
@@ -138,7 +147,7 @@ class OnionPacket:
return ret
@classmethod
def from_bytes(cls, b: bytes):
def from_bytes(cls, b: bytes) -> 'OnionPacket':
if len(b) - 66 not in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]:
raise Exception('unexpected length {}'.format(len(b)))
return OnionPacket(
@@ -187,7 +196,7 @@ def get_blinded_node_id(node_id: bytes, shared_secret: bytes):
def new_onion_packet(
payment_path_pubkeys: Sequence[bytes],
session_key: bytes,
hops_data: Sequence[OnionHopsDataSingle],
hops_data: List[OnionHopsDataSingle],
*,
associated_data: bytes = b'',
trampoline: bool = False,
@@ -226,7 +235,7 @@ def new_onion_packet(
for i in range(num_hops-1, -1, -1):
rho_key = get_bolt04_onion_key(b'rho', hop_shared_secrets[i])
mu_key = get_bolt04_onion_key(b'mu', hop_shared_secrets[i])
hops_data[i].hmac = next_hmac
hops_data[i] = replace(hops_data[i], hmac=next_hmac)
stream_bytes = generate_cipher_stream(rho_key, data_size)
hop_data_bytes = hops_data[i].to_bytes()
mix_header = mix_header[:-len(hop_data_bytes)]
@@ -294,7 +303,7 @@ def calc_hops_data_for_payment(
"total_msat": total_msat,
"amount_msat": amt
}
hops_data = [OnionHopsDataSingle(payload=hop_payload)]
hops_data = [OnionHopsDataSingle(payload=MappingProxyType(hop_payload))]
# payloads, backwards from last hop (but excluding the first edge):
for edge_index in range(len(route) - 1, 0, -1):
route_edge = route[edge_index]
@@ -304,7 +313,7 @@ def calc_hops_data_for_payment(
"short_channel_id": {"short_channel_id": route_edge.short_channel_id},
}
hops_data.append(
OnionHopsDataSingle(payload=hop_payload))
OnionHopsDataSingle(payload=MappingProxyType(hop_payload)))
amt += route_edge.fee_for_edge(amt)
cltv_abs += route_edge.cltv_delta
hops_data.reverse()