lnonion: immutable OnionPacket and OnionHopsDataSingle
Make OnionHopsDataSingle and OnionPacket immutable for safer caching and handling. # Conflicts: # electrum/onion_message.py
This commit is contained in:
@@ -27,6 +27,8 @@ import io
|
||||
import hashlib
|
||||
from typing import Sequence, List, Tuple, NamedTuple, TYPE_CHECKING, Dict, Any, Optional, Union
|
||||
from enum import IntEnum
|
||||
from dataclasses import dataclass, field, replace
|
||||
from types import MappingProxyType
|
||||
|
||||
import electrum_ecc as ecc
|
||||
|
||||
@@ -53,18 +55,22 @@ class InvalidOnionPubkey(Exception): pass
|
||||
class InvalidPayloadSize(Exception): pass
|
||||
|
||||
|
||||
class OnionHopsDataSingle: # called HopData in lnd
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class OnionHopsDataSingle:
|
||||
payload: MappingProxyType = field(default_factory=lambda: MappingProxyType({}))
|
||||
hmac: Optional[bytes] = None
|
||||
tlv_stream_name: str = 'payload'
|
||||
blind_fields: MappingProxyType = field(default_factory=lambda: MappingProxyType({}))
|
||||
_raw_bytes_payload: Optional[bytes] = None
|
||||
|
||||
def __init__(self, *, payload: dict = None, tlv_stream_name: str = 'payload', blind_fields: dict = None):
|
||||
if payload is None:
|
||||
payload = {}
|
||||
self.payload = payload
|
||||
self.hmac = None
|
||||
self.tlv_stream_name = tlv_stream_name
|
||||
if blind_fields is None:
|
||||
blind_fields = {}
|
||||
self.blind_fields = blind_fields
|
||||
self._raw_bytes_payload = None # used in unit tests
|
||||
def __post_init__(self):
|
||||
# make all fields immutable recursively
|
||||
object.__setattr__(self, 'payload', util.make_object_immutable(self.payload))
|
||||
object.__setattr__(self, 'blind_fields', util.make_object_immutable(self.blind_fields))
|
||||
assert isinstance(self.payload, MappingProxyType)
|
||||
assert isinstance(self.blind_fields, MappingProxyType)
|
||||
assert isinstance(self.tlv_stream_name, str)
|
||||
assert (isinstance(self.hmac, bytes) and len(self.hmac) == PER_HOP_HMAC_SIZE) or self.hmac is None
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
hmac_ = self.hmac if self.hmac is not None else bytes(PER_HOP_HMAC_SIZE)
|
||||
@@ -101,32 +107,35 @@ class OnionHopsDataSingle: # called HopData in lnd
|
||||
hop_payload = fd.read(hop_payload_length)
|
||||
if hop_payload_length != len(hop_payload):
|
||||
raise Exception(f"unexpected EOF")
|
||||
ret = OnionHopsDataSingle(tlv_stream_name=tlv_stream_name)
|
||||
ret.payload = OnionWireSerializer.read_tlv_stream(fd=io.BytesIO(hop_payload),
|
||||
tlv_stream_name=tlv_stream_name)
|
||||
ret.hmac = fd.read(PER_HOP_HMAC_SIZE)
|
||||
assert len(ret.hmac) == PER_HOP_HMAC_SIZE
|
||||
payload = OnionWireSerializer.read_tlv_stream(fd=io.BytesIO(hop_payload),
|
||||
tlv_stream_name=tlv_stream_name)
|
||||
ret = OnionHopsDataSingle(
|
||||
tlv_stream_name=tlv_stream_name,
|
||||
payload=MappingProxyType(payload),
|
||||
hmac=fd.read(PER_HOP_HMAC_SIZE)
|
||||
)
|
||||
return ret
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OnionHopsDataSingle. payload={self.payload}. hmac={self.hmac}>"
|
||||
return f"<OnionHopsDataSingle. {self.payload=}. {self.hmac=}>"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class OnionPacket:
|
||||
public_key: bytes
|
||||
hops_data: bytes # also called RoutingInfo in bolt-04
|
||||
hmac: bytes
|
||||
version: int = 0
|
||||
# for debugging our own onions:
|
||||
_debug_hops_data: Optional[Sequence[OnionHopsDataSingle]] = None
|
||||
_debug_route: Optional['LNPaymentRoute'] = None
|
||||
|
||||
def __init__(self, *, public_key: bytes, hops_data: bytes, hmac: bytes, version: int = 0):
|
||||
assert len(public_key) == 33
|
||||
assert len(hops_data) in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]
|
||||
assert len(hmac) == PER_HOP_HMAC_SIZE
|
||||
self.version = version
|
||||
self.public_key = public_key
|
||||
self.hops_data = hops_data # also called RoutingInfo in bolt-04
|
||||
self.hmac = hmac
|
||||
if not ecc.ECPubkey.is_pubkey_bytes(public_key):
|
||||
def __post_init__(self):
|
||||
assert len(self.public_key) == 33
|
||||
assert len(self.hops_data) in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]
|
||||
assert len(self.hmac) == PER_HOP_HMAC_SIZE
|
||||
if not ecc.ECPubkey.is_pubkey_bytes(self.public_key):
|
||||
raise InvalidOnionPubkey()
|
||||
# for debugging our own onions:
|
||||
self._debug_hops_data = None # type: Optional[Sequence[OnionHopsDataSingle]]
|
||||
self._debug_route = None # type: Optional[LNPaymentRoute]
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
ret = bytes([self.version])
|
||||
@@ -138,7 +147,7 @@ class OnionPacket:
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, b: bytes):
|
||||
def from_bytes(cls, b: bytes) -> 'OnionPacket':
|
||||
if len(b) - 66 not in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]:
|
||||
raise Exception('unexpected length {}'.format(len(b)))
|
||||
return OnionPacket(
|
||||
@@ -187,7 +196,7 @@ def get_blinded_node_id(node_id: bytes, shared_secret: bytes):
|
||||
def new_onion_packet(
|
||||
payment_path_pubkeys: Sequence[bytes],
|
||||
session_key: bytes,
|
||||
hops_data: Sequence[OnionHopsDataSingle],
|
||||
hops_data: List[OnionHopsDataSingle],
|
||||
*,
|
||||
associated_data: bytes = b'',
|
||||
trampoline: bool = False,
|
||||
@@ -226,7 +235,7 @@ def new_onion_packet(
|
||||
for i in range(num_hops-1, -1, -1):
|
||||
rho_key = get_bolt04_onion_key(b'rho', hop_shared_secrets[i])
|
||||
mu_key = get_bolt04_onion_key(b'mu', hop_shared_secrets[i])
|
||||
hops_data[i].hmac = next_hmac
|
||||
hops_data[i] = replace(hops_data[i], hmac=next_hmac)
|
||||
stream_bytes = generate_cipher_stream(rho_key, data_size)
|
||||
hop_data_bytes = hops_data[i].to_bytes()
|
||||
mix_header = mix_header[:-len(hop_data_bytes)]
|
||||
@@ -294,7 +303,7 @@ def calc_hops_data_for_payment(
|
||||
"total_msat": total_msat,
|
||||
"amount_msat": amt
|
||||
}
|
||||
hops_data = [OnionHopsDataSingle(payload=hop_payload)]
|
||||
hops_data = [OnionHopsDataSingle(payload=MappingProxyType(hop_payload))]
|
||||
# payloads, backwards from last hop (but excluding the first edge):
|
||||
for edge_index in range(len(route) - 1, 0, -1):
|
||||
route_edge = route[edge_index]
|
||||
@@ -304,7 +313,7 @@ def calc_hops_data_for_payment(
|
||||
"short_channel_id": {"short_channel_id": route_edge.short_channel_id},
|
||||
}
|
||||
hops_data.append(
|
||||
OnionHopsDataSingle(payload=hop_payload))
|
||||
OnionHopsDataSingle(payload=MappingProxyType(hop_payload)))
|
||||
amt += route_edge.fee_for_edge(amt)
|
||||
cltv_abs += route_edge.cltv_delta
|
||||
hops_data.reverse()
|
||||
|
||||
Reference in New Issue
Block a user