add onion message support
This commit is contained in:
@@ -30,7 +30,7 @@ from enum import IntEnum
|
||||
|
||||
import electrum_ecc as ecc
|
||||
|
||||
from .crypto import sha256, hmac_oneshot, chacha20_encrypt, get_ecdh
|
||||
from .crypto import sha256, hmac_oneshot, chacha20_encrypt, get_ecdh, chacha20_poly1305_encrypt, chacha20_poly1305_decrypt
|
||||
from .util import profiler, xor_bytes, bfh
|
||||
from .lnutil import (PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH,
|
||||
NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID, OnionFailureCodeMetaFlag)
|
||||
@@ -44,20 +44,25 @@ if TYPE_CHECKING:
|
||||
HOPS_DATA_SIZE = 1300 # also sometimes called routingInfoSize in bolt-04
|
||||
TRAMPOLINE_HOPS_DATA_SIZE = 400
|
||||
PER_HOP_HMAC_SIZE = 32
|
||||
|
||||
ONION_MESSAGE_LARGE_SIZE = 32768
|
||||
|
||||
class UnsupportedOnionPacketVersion(Exception): pass
|
||||
class InvalidOnionMac(Exception): pass
|
||||
class InvalidOnionPubkey(Exception): pass
|
||||
class InvalidPayloadSize(Exception): pass
|
||||
|
||||
|
||||
class OnionHopsDataSingle: # called HopData in lnd
|
||||
|
||||
def __init__(self, *, payload: dict = None):
|
||||
def __init__(self, *, payload: dict = None, tlv_stream_name: str = 'payload', blind_fields: dict = None):
|
||||
if payload is None:
|
||||
payload = {}
|
||||
self.payload = payload
|
||||
self.hmac = None
|
||||
self.tlv_stream_name = tlv_stream_name
|
||||
if blind_fields is None:
|
||||
blind_fields = {}
|
||||
self.blind_fields = blind_fields
|
||||
self._raw_bytes_payload = None # used in unit tests
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
@@ -69,7 +74,7 @@ class OnionHopsDataSingle: # called HopData in lnd
|
||||
# adding TLV payload. note: legacy hop data format no longer supported.
|
||||
payload_fd = io.BytesIO()
|
||||
OnionWireSerializer.write_tlv_stream(fd=payload_fd,
|
||||
tlv_stream_name="payload",
|
||||
tlv_stream_name=self.tlv_stream_name,
|
||||
**self.payload)
|
||||
payload_bytes = payload_fd.getvalue()
|
||||
with io.BytesIO() as fd:
|
||||
@@ -79,7 +84,7 @@ class OnionHopsDataSingle: # called HopData in lnd
|
||||
return fd.getvalue()
|
||||
|
||||
@classmethod
|
||||
def from_fd(cls, fd: io.BytesIO) -> 'OnionHopsDataSingle':
|
||||
def from_fd(cls, fd: io.BytesIO, *, tlv_stream_name: str = 'payload') -> 'OnionHopsDataSingle':
|
||||
first_byte = fd.read(1)
|
||||
if len(first_byte) == 0:
|
||||
raise Exception(f"unexpected EOF")
|
||||
@@ -95,9 +100,9 @@ class OnionHopsDataSingle: # called HopData in lnd
|
||||
hop_payload = fd.read(hop_payload_length)
|
||||
if hop_payload_length != len(hop_payload):
|
||||
raise Exception(f"unexpected EOF")
|
||||
ret = OnionHopsDataSingle()
|
||||
ret = OnionHopsDataSingle(tlv_stream_name=tlv_stream_name)
|
||||
ret.payload = OnionWireSerializer.read_tlv_stream(fd=io.BytesIO(hop_payload),
|
||||
tlv_stream_name="payload")
|
||||
tlv_stream_name=tlv_stream_name)
|
||||
ret.hmac = fd.read(PER_HOP_HMAC_SIZE)
|
||||
assert len(ret.hmac) == PER_HOP_HMAC_SIZE
|
||||
return ret
|
||||
@@ -110,7 +115,7 @@ class OnionPacket:
|
||||
|
||||
def __init__(self, public_key: bytes, hops_data: bytes, hmac: bytes):
|
||||
assert len(public_key) == 33
|
||||
assert len(hops_data) in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE]
|
||||
assert len(hops_data) in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]
|
||||
assert len(hmac) == PER_HOP_HMAC_SIZE
|
||||
self.version = 0
|
||||
self.public_key = public_key
|
||||
@@ -127,13 +132,13 @@ class OnionPacket:
|
||||
ret += self.public_key
|
||||
ret += self.hops_data
|
||||
ret += self.hmac
|
||||
if len(ret) - 66 not in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE]:
|
||||
if len(ret) - 66 not in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]:
|
||||
raise Exception('unexpected length {}'.format(len(ret)))
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, b: bytes):
|
||||
if len(b) - 66 not in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE]:
|
||||
if len(b) - 66 not in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]:
|
||||
raise Exception('unexpected length {}'.format(len(b)))
|
||||
version = b[0]
|
||||
if version != 0:
|
||||
@@ -146,27 +151,38 @@ class OnionPacket:
|
||||
|
||||
|
||||
def get_bolt04_onion_key(key_type: bytes, secret: bytes) -> bytes:
|
||||
if key_type not in (b'rho', b'mu', b'um', b'ammag', b'pad'):
|
||||
if key_type not in (b'rho', b'mu', b'um', b'ammag', b'pad', b'blinded_node_id'):
|
||||
raise Exception('invalid key_type {}'.format(key_type))
|
||||
key = hmac_oneshot(key_type, msg=secret, digest=hashlib.sha256)
|
||||
return key
|
||||
|
||||
|
||||
def get_shared_secrets_along_route(payment_path_pubkeys: Sequence[bytes],
|
||||
session_key: bytes) -> Sequence[bytes]:
|
||||
session_key: bytes) -> Tuple[Sequence[bytes], Sequence[bytes]]:
|
||||
num_hops = len(payment_path_pubkeys)
|
||||
hop_shared_secrets = num_hops * [b'']
|
||||
hop_blinded_node_ids = num_hops * [b'']
|
||||
ephemeral_key = session_key
|
||||
# compute shared key for each hop
|
||||
for i in range(0, num_hops):
|
||||
hop_shared_secrets[i] = get_ecdh(ephemeral_key, payment_path_pubkeys[i])
|
||||
hop_blinded_node_ids[i] = get_blinded_node_id(payment_path_pubkeys[i], hop_shared_secrets[i])
|
||||
ephemeral_pubkey = ecc.ECPrivkey(ephemeral_key).get_public_key_bytes()
|
||||
blinding_factor = sha256(ephemeral_pubkey + hop_shared_secrets[i])
|
||||
blinding_factor_int = int.from_bytes(blinding_factor, byteorder="big")
|
||||
ephemeral_key_int = int.from_bytes(ephemeral_key, byteorder="big")
|
||||
ephemeral_key_int = ephemeral_key_int * blinding_factor_int % ecc.CURVE_ORDER
|
||||
ephemeral_key = ephemeral_key_int.to_bytes(32, byteorder="big")
|
||||
return hop_shared_secrets
|
||||
return hop_shared_secrets, hop_blinded_node_ids
|
||||
|
||||
|
||||
def get_blinded_node_id(node_id: bytes, shared_secret: bytes):
|
||||
# blinded node id
|
||||
# B(i) = HMAC256("blinded_node_id", ss(i)) * N(i)
|
||||
ss_bni_hmac = get_bolt04_onion_key(b'blinded_node_id', shared_secret)
|
||||
ss_bni_hmac_int = int.from_bytes(ss_bni_hmac, byteorder="big")
|
||||
blinded_node_id = ecc.ECPubkey(node_id) * ss_bni_hmac_int
|
||||
return blinded_node_id.get_public_key_bytes()
|
||||
|
||||
|
||||
def new_onion_packet(
|
||||
@@ -174,14 +190,31 @@ def new_onion_packet(
|
||||
session_key: bytes,
|
||||
hops_data: Sequence[OnionHopsDataSingle],
|
||||
*,
|
||||
associated_data: bytes,
|
||||
associated_data: bytes = b'',
|
||||
trampoline: bool = False,
|
||||
onion_message: bool = False
|
||||
) -> OnionPacket:
|
||||
num_hops = len(payment_path_pubkeys)
|
||||
assert num_hops == len(hops_data)
|
||||
hop_shared_secrets = get_shared_secrets_along_route(payment_path_pubkeys, session_key)
|
||||
hop_shared_secrets, _ = get_shared_secrets_along_route(payment_path_pubkeys, session_key)
|
||||
|
||||
payload_size = 0
|
||||
for i in range(num_hops):
|
||||
# FIXME: serializing here and again below. cache bytes in OnionHopsDataSingle? _raw_bytes_payload?
|
||||
payload_size += PER_HOP_HMAC_SIZE + len(hops_data[i].to_bytes())
|
||||
if trampoline:
|
||||
data_size = TRAMPOLINE_HOPS_DATA_SIZE
|
||||
elif onion_message:
|
||||
if payload_size <= HOPS_DATA_SIZE:
|
||||
data_size = HOPS_DATA_SIZE
|
||||
else:
|
||||
data_size = ONION_MESSAGE_LARGE_SIZE
|
||||
else:
|
||||
data_size = HOPS_DATA_SIZE
|
||||
|
||||
if payload_size > data_size:
|
||||
raise InvalidPayloadSize(f'payload too big for onion packet (max={data_size}, required={payload_size})')
|
||||
|
||||
data_size = TRAMPOLINE_HOPS_DATA_SIZE if trampoline else HOPS_DATA_SIZE
|
||||
filler = _generate_filler(b'rho', hops_data, hop_shared_secrets, data_size)
|
||||
next_hmac = bytes(PER_HOP_HMAC_SIZE)
|
||||
|
||||
@@ -211,6 +244,30 @@ def new_onion_packet(
|
||||
hmac=next_hmac)
|
||||
|
||||
|
||||
def encrypt_onionmsg_data_tlv(*, shared_secret, **kwargs):
|
||||
rho_key = get_bolt04_onion_key(b'rho', shared_secret)
|
||||
with io.BytesIO() as encrypted_data_tlv_fd:
|
||||
OnionWireSerializer.write_tlv_stream(
|
||||
fd=encrypted_data_tlv_fd,
|
||||
tlv_stream_name='encrypted_data_tlv',
|
||||
**kwargs)
|
||||
encrypted_data_tlv_bytes = encrypted_data_tlv_fd.getvalue()
|
||||
encrypted_recipient_data = chacha20_poly1305_encrypt(
|
||||
key=rho_key, nonce=bytes(12),
|
||||
data=encrypted_data_tlv_bytes)
|
||||
return encrypted_recipient_data
|
||||
|
||||
|
||||
def decrypt_onionmsg_data_tlv(*, shared_secret: bytes, encrypted_recipient_data: bytes) -> dict:
|
||||
rho_key = get_bolt04_onion_key(b'rho', shared_secret)
|
||||
recipient_data_bytes = chacha20_poly1305_decrypt(key=rho_key, nonce=bytes(12), data=encrypted_recipient_data)
|
||||
|
||||
with io.BytesIO(recipient_data_bytes) as fd:
|
||||
recipient_data = OnionWireSerializer.read_tlv_stream(fd=fd, tlv_stream_name='encrypted_data_tlv')
|
||||
|
||||
return recipient_data
|
||||
|
||||
|
||||
def calc_hops_data_for_payment(
|
||||
route: 'LNPaymentRoute',
|
||||
amount_msat: int, # that final recipient receives
|
||||
@@ -299,9 +356,11 @@ class ProcessedOnionPacket(NamedTuple):
|
||||
# TODO replay protection
|
||||
def process_onion_packet(
|
||||
onion_packet: OnionPacket,
|
||||
associated_data: bytes,
|
||||
our_onion_private_key: bytes,
|
||||
is_trampoline=False) -> ProcessedOnionPacket:
|
||||
*,
|
||||
associated_data: bytes = b'',
|
||||
is_trampoline=False,
|
||||
tlv_stream_name='payload') -> ProcessedOnionPacket:
|
||||
if not ecc.ECPubkey.is_pubkey_bytes(onion_packet.public_key):
|
||||
raise InvalidOnionPubkey()
|
||||
shared_secret = get_ecdh(our_onion_private_key, onion_packet.public_key)
|
||||
@@ -319,7 +378,7 @@ def process_onion_packet(
|
||||
padded_header = onion_packet.hops_data + bytes(data_size)
|
||||
next_hops_data = xor_bytes(padded_header, stream_bytes)
|
||||
next_hops_data_fd = io.BytesIO(next_hops_data)
|
||||
hop_data = OnionHopsDataSingle.from_fd(next_hops_data_fd)
|
||||
hop_data = OnionHopsDataSingle.from_fd(next_hops_data_fd, tlv_stream_name=tlv_stream_name)
|
||||
# trampoline
|
||||
trampoline_onion_packet = hop_data.payload.get('trampoline_onion_packet')
|
||||
if trampoline_onion_packet:
|
||||
@@ -427,7 +486,7 @@ def _decode_onion_error(error_packet: bytes, payment_path_pubkeys: Sequence[byte
|
||||
session_key: bytes) -> Tuple[bytes, int]:
|
||||
"""Returns the decoded error bytes, and the index of the sender of the error."""
|
||||
num_hops = len(payment_path_pubkeys)
|
||||
hop_shared_secrets = get_shared_secrets_along_route(payment_path_pubkeys, session_key)
|
||||
hop_shared_secrets, _ = get_shared_secrets_along_route(payment_path_pubkeys, session_key)
|
||||
for i in range(num_hops):
|
||||
ammag_key = get_bolt04_onion_key(b'ammag', hop_shared_secrets[i])
|
||||
um_key = get_bolt04_onion_key(b'um', hop_shared_secrets[i])
|
||||
|
||||
Reference in New Issue
Block a user