1
0

add onion message support

This commit is contained in:
Sander van Grieken
2024-11-22 14:46:21 +01:00
parent 12ffbfc29e
commit 7b4180202a
10 changed files with 1387 additions and 33 deletions

View File

@@ -30,7 +30,7 @@ from enum import IntEnum
import electrum_ecc as ecc
from .crypto import sha256, hmac_oneshot, chacha20_encrypt, get_ecdh
from .crypto import sha256, hmac_oneshot, chacha20_encrypt, get_ecdh, chacha20_poly1305_encrypt, chacha20_poly1305_decrypt
from .util import profiler, xor_bytes, bfh
from .lnutil import (PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH,
NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID, OnionFailureCodeMetaFlag)
@@ -44,20 +44,25 @@ if TYPE_CHECKING:
HOPS_DATA_SIZE = 1300 # also sometimes called routingInfoSize in bolt-04
TRAMPOLINE_HOPS_DATA_SIZE = 400
PER_HOP_HMAC_SIZE = 32
ONION_MESSAGE_LARGE_SIZE = 32768
class UnsupportedOnionPacketVersion(Exception): pass
class InvalidOnionMac(Exception): pass
class InvalidOnionPubkey(Exception): pass
class InvalidPayloadSize(Exception): pass
class OnionHopsDataSingle: # called HopData in lnd
def __init__(self, *, payload: dict = None):
def __init__(self, *, payload: dict = None, tlv_stream_name: str = 'payload', blind_fields: dict = None):
if payload is None:
payload = {}
self.payload = payload
self.hmac = None
self.tlv_stream_name = tlv_stream_name
if blind_fields is None:
blind_fields = {}
self.blind_fields = blind_fields
self._raw_bytes_payload = None # used in unit tests
def to_bytes(self) -> bytes:
@@ -69,7 +74,7 @@ class OnionHopsDataSingle: # called HopData in lnd
# adding TLV payload. note: legacy hop data format no longer supported.
payload_fd = io.BytesIO()
OnionWireSerializer.write_tlv_stream(fd=payload_fd,
tlv_stream_name="payload",
tlv_stream_name=self.tlv_stream_name,
**self.payload)
payload_bytes = payload_fd.getvalue()
with io.BytesIO() as fd:
@@ -79,7 +84,7 @@ class OnionHopsDataSingle: # called HopData in lnd
return fd.getvalue()
@classmethod
def from_fd(cls, fd: io.BytesIO) -> 'OnionHopsDataSingle':
def from_fd(cls, fd: io.BytesIO, *, tlv_stream_name: str = 'payload') -> 'OnionHopsDataSingle':
first_byte = fd.read(1)
if len(first_byte) == 0:
raise Exception(f"unexpected EOF")
@@ -95,9 +100,9 @@ class OnionHopsDataSingle: # called HopData in lnd
hop_payload = fd.read(hop_payload_length)
if hop_payload_length != len(hop_payload):
raise Exception(f"unexpected EOF")
ret = OnionHopsDataSingle()
ret = OnionHopsDataSingle(tlv_stream_name=tlv_stream_name)
ret.payload = OnionWireSerializer.read_tlv_stream(fd=io.BytesIO(hop_payload),
tlv_stream_name="payload")
tlv_stream_name=tlv_stream_name)
ret.hmac = fd.read(PER_HOP_HMAC_SIZE)
assert len(ret.hmac) == PER_HOP_HMAC_SIZE
return ret
@@ -110,7 +115,7 @@ class OnionPacket:
def __init__(self, public_key: bytes, hops_data: bytes, hmac: bytes):
assert len(public_key) == 33
assert len(hops_data) in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE]
assert len(hops_data) in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]
assert len(hmac) == PER_HOP_HMAC_SIZE
self.version = 0
self.public_key = public_key
@@ -127,13 +132,13 @@ class OnionPacket:
ret += self.public_key
ret += self.hops_data
ret += self.hmac
if len(ret) - 66 not in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE]:
if len(ret) - 66 not in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]:
raise Exception('unexpected length {}'.format(len(ret)))
return ret
@classmethod
def from_bytes(cls, b: bytes):
if len(b) - 66 not in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE]:
if len(b) - 66 not in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]:
raise Exception('unexpected length {}'.format(len(b)))
version = b[0]
if version != 0:
@@ -146,27 +151,38 @@ class OnionPacket:
def get_bolt04_onion_key(key_type: bytes, secret: bytes) -> bytes:
if key_type not in (b'rho', b'mu', b'um', b'ammag', b'pad'):
if key_type not in (b'rho', b'mu', b'um', b'ammag', b'pad', b'blinded_node_id'):
raise Exception('invalid key_type {}'.format(key_type))
key = hmac_oneshot(key_type, msg=secret, digest=hashlib.sha256)
return key
def get_shared_secrets_along_route(payment_path_pubkeys: Sequence[bytes],
session_key: bytes) -> Sequence[bytes]:
session_key: bytes) -> Tuple[Sequence[bytes], Sequence[bytes]]:
num_hops = len(payment_path_pubkeys)
hop_shared_secrets = num_hops * [b'']
hop_blinded_node_ids = num_hops * [b'']
ephemeral_key = session_key
# compute shared key for each hop
for i in range(0, num_hops):
hop_shared_secrets[i] = get_ecdh(ephemeral_key, payment_path_pubkeys[i])
hop_blinded_node_ids[i] = get_blinded_node_id(payment_path_pubkeys[i], hop_shared_secrets[i])
ephemeral_pubkey = ecc.ECPrivkey(ephemeral_key).get_public_key_bytes()
blinding_factor = sha256(ephemeral_pubkey + hop_shared_secrets[i])
blinding_factor_int = int.from_bytes(blinding_factor, byteorder="big")
ephemeral_key_int = int.from_bytes(ephemeral_key, byteorder="big")
ephemeral_key_int = ephemeral_key_int * blinding_factor_int % ecc.CURVE_ORDER
ephemeral_key = ephemeral_key_int.to_bytes(32, byteorder="big")
return hop_shared_secrets
return hop_shared_secrets, hop_blinded_node_ids
def get_blinded_node_id(node_id: bytes, shared_secret: bytes):
# blinded node id
# B(i) = HMAC256("blinded_node_id", ss(i)) * N(i)
ss_bni_hmac = get_bolt04_onion_key(b'blinded_node_id', shared_secret)
ss_bni_hmac_int = int.from_bytes(ss_bni_hmac, byteorder="big")
blinded_node_id = ecc.ECPubkey(node_id) * ss_bni_hmac_int
return blinded_node_id.get_public_key_bytes()
def new_onion_packet(
@@ -174,14 +190,31 @@ def new_onion_packet(
session_key: bytes,
hops_data: Sequence[OnionHopsDataSingle],
*,
associated_data: bytes,
associated_data: bytes = b'',
trampoline: bool = False,
onion_message: bool = False
) -> OnionPacket:
num_hops = len(payment_path_pubkeys)
assert num_hops == len(hops_data)
hop_shared_secrets = get_shared_secrets_along_route(payment_path_pubkeys, session_key)
hop_shared_secrets, _ = get_shared_secrets_along_route(payment_path_pubkeys, session_key)
payload_size = 0
for i in range(num_hops):
# FIXME: serializing here and again below. cache bytes in OnionHopsDataSingle? _raw_bytes_payload?
payload_size += PER_HOP_HMAC_SIZE + len(hops_data[i].to_bytes())
if trampoline:
data_size = TRAMPOLINE_HOPS_DATA_SIZE
elif onion_message:
if payload_size <= HOPS_DATA_SIZE:
data_size = HOPS_DATA_SIZE
else:
data_size = ONION_MESSAGE_LARGE_SIZE
else:
data_size = HOPS_DATA_SIZE
if payload_size > data_size:
raise InvalidPayloadSize(f'payload too big for onion packet (max={data_size}, required={payload_size})')
data_size = TRAMPOLINE_HOPS_DATA_SIZE if trampoline else HOPS_DATA_SIZE
filler = _generate_filler(b'rho', hops_data, hop_shared_secrets, data_size)
next_hmac = bytes(PER_HOP_HMAC_SIZE)
@@ -211,6 +244,30 @@ def new_onion_packet(
hmac=next_hmac)
def encrypt_onionmsg_data_tlv(*, shared_secret, **kwargs):
rho_key = get_bolt04_onion_key(b'rho', shared_secret)
with io.BytesIO() as encrypted_data_tlv_fd:
OnionWireSerializer.write_tlv_stream(
fd=encrypted_data_tlv_fd,
tlv_stream_name='encrypted_data_tlv',
**kwargs)
encrypted_data_tlv_bytes = encrypted_data_tlv_fd.getvalue()
encrypted_recipient_data = chacha20_poly1305_encrypt(
key=rho_key, nonce=bytes(12),
data=encrypted_data_tlv_bytes)
return encrypted_recipient_data
def decrypt_onionmsg_data_tlv(*, shared_secret: bytes, encrypted_recipient_data: bytes) -> dict:
rho_key = get_bolt04_onion_key(b'rho', shared_secret)
recipient_data_bytes = chacha20_poly1305_decrypt(key=rho_key, nonce=bytes(12), data=encrypted_recipient_data)
with io.BytesIO(recipient_data_bytes) as fd:
recipient_data = OnionWireSerializer.read_tlv_stream(fd=fd, tlv_stream_name='encrypted_data_tlv')
return recipient_data
def calc_hops_data_for_payment(
route: 'LNPaymentRoute',
amount_msat: int, # that final recipient receives
@@ -299,9 +356,11 @@ class ProcessedOnionPacket(NamedTuple):
# TODO replay protection
def process_onion_packet(
onion_packet: OnionPacket,
associated_data: bytes,
our_onion_private_key: bytes,
is_trampoline=False) -> ProcessedOnionPacket:
*,
associated_data: bytes = b'',
is_trampoline=False,
tlv_stream_name='payload') -> ProcessedOnionPacket:
if not ecc.ECPubkey.is_pubkey_bytes(onion_packet.public_key):
raise InvalidOnionPubkey()
shared_secret = get_ecdh(our_onion_private_key, onion_packet.public_key)
@@ -319,7 +378,7 @@ def process_onion_packet(
padded_header = onion_packet.hops_data + bytes(data_size)
next_hops_data = xor_bytes(padded_header, stream_bytes)
next_hops_data_fd = io.BytesIO(next_hops_data)
hop_data = OnionHopsDataSingle.from_fd(next_hops_data_fd)
hop_data = OnionHopsDataSingle.from_fd(next_hops_data_fd, tlv_stream_name=tlv_stream_name)
# trampoline
trampoline_onion_packet = hop_data.payload.get('trampoline_onion_packet')
if trampoline_onion_packet:
@@ -427,7 +486,7 @@ def _decode_onion_error(error_packet: bytes, payment_path_pubkeys: Sequence[byte
session_key: bytes) -> Tuple[bytes, int]:
"""Returns the decoded error bytes, and the index of the sender of the error."""
num_hops = len(payment_path_pubkeys)
hop_shared_secrets = get_shared_secrets_along_route(payment_path_pubkeys, session_key)
hop_shared_secrets, _ = get_shared_secrets_along_route(payment_path_pubkeys, session_key)
for i in range(num_hops):
ammag_key = get_bolt04_onion_key(b'ammag', hop_shared_secrets[i])
um_key = get_bolt04_onion_key(b'um', hop_shared_secrets[i])