1
0

lnmsg: support both primitive and complex types (subtypes) in LNSerializer.

This renames lnmsg._{read,write}_field to lnmsg._{read,write}_primitive_field, renames
LNSerializer._{read,write}_complex_type to LNSerializer.{read,write}_field and allows
LNSerializer.{read,write}_field to handle both primitive and complex types.

Also makes these funcs public, as these encodings are used outside of lnmsg as well
(e.g. encoding blinded paths in BOLT12 invoice_request)
This commit is contained in:
Sander van Grieken
2025-02-12 14:28:30 +01:00
parent e216f1b324
commit 6e35ffe4b5
4 changed files with 69 additions and 71 deletions

View File

@@ -1516,7 +1516,7 @@ class Commands(Logger):
blinded_path = create_blinded_path(session_key, path=path, final_recipient_data={}, dummy_hops=dummy_hops)
with io.BytesIO() as blinded_path_fd:
OnionWireSerializer._write_complex_field(
OnionWireSerializer.write_field(
fd=blinded_path_fd,
field_type='blinded_path',
count=1,

View File

@@ -88,8 +88,14 @@ def read_bigsize_int(fd: io.BytesIO) -> Optional[int]:
# TODO: maybe if field_type is not "byte", we could return a list of type_len sized chunks?
# if field_type is a numeric, we could return a list of ints?
def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> Union[bytes, int]:
if not fd: raise Exception()
def _read_primitive_field(
*,
fd: io.BytesIO,
field_type: str,
count: Union[int, str]
) -> Union[bytes, int]:
if not fd:
raise Exception()
if isinstance(count, int):
assert count >= 0, f"{count!r} must be non-neg int"
elif count == "...":
@@ -174,9 +180,15 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> U
# TODO: maybe for "value" we could accept a list with len "count" of appropriate items
def _write_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str],
value: Union[bytes, int]) -> None:
if not fd: raise Exception()
def _write_primitive_field(
*,
fd: io.BytesIO,
field_type: str,
count: Union[int, str],
value: Union[bytes, int]
) -> None:
if not fd:
raise Exception()
if isinstance(count, int):
assert count >= 0, f"{count!r} must be non-neg int"
elif count == "...":
@@ -263,18 +275,18 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str],
def _read_tlv_record(*, fd: io.BytesIO) -> Tuple[int, bytes]:
if not fd: raise Exception()
tlv_type = _read_field(fd=fd, field_type="bigsize", count=1)
tlv_len = _read_field(fd=fd, field_type="bigsize", count=1)
tlv_val = _read_field(fd=fd, field_type="byte", count=tlv_len)
tlv_type = _read_primitive_field(fd=fd, field_type="bigsize", count=1)
tlv_len = _read_primitive_field(fd=fd, field_type="bigsize", count=1)
tlv_val = _read_primitive_field(fd=fd, field_type="byte", count=tlv_len)
return tlv_type, tlv_val
def _write_tlv_record(*, fd: io.BytesIO, tlv_type: int, tlv_val: bytes) -> None:
if not fd: raise Exception()
tlv_len = len(tlv_val)
_write_field(fd=fd, field_type="bigsize", count=1, value=tlv_type)
_write_field(fd=fd, field_type="bigsize", count=1, value=tlv_len)
_write_field(fd=fd, field_type="byte", count=tlv_len, value=tlv_val)
_write_primitive_field(fd=fd, field_type="bigsize", count=1, value=tlv_type)
_write_primitive_field(fd=fd, field_type="bigsize", count=1, value=tlv_len)
_write_primitive_field(fd=fd, field_type="byte", count=tlv_len, value=tlv_val)
def _resolve_field_count(field_count_str: str, *, vars_dict: dict, allow_any=False) -> Union[int, str]:
@@ -385,10 +397,19 @@ class LNSerializer:
else:
pass # TODO
def _write_complex_field(self, *, fd: io.BytesIO, field_type: str, count: Union[int, str],
value: Union[List[Dict[str, Any]], Dict[str, Any]]) -> None:
def write_field(
self,
*,
fd: io.BytesIO,
field_type: str,
count: Union[int, str],
value: Union[List[Dict[str, Any]], Dict[str, Any]]
) -> None:
assert fd
assert field_type in self.subtypes, f"unknown subtype {field_type}"
if field_type not in self.subtypes:
_write_primitive_field(fd=fd, field_type=field_type, count=count, value=value)
return
if isinstance(count, int):
assert count >= 0, f"{count!r} must be non-neg int"
@@ -428,22 +449,24 @@ class LNSerializer:
if subtype_field_name not in record:
raise Exception(f'complex field type {field_type} missing element {subtype_field_name}')
if subtype_field_type in self.subtypes:
self._write_complex_field(
fd=fd,
field_type=subtype_field_type,
count=subtype_field_count,
value=record[subtype_field_name])
else:
_write_field(
fd=fd,
field_type=subtype_field_type,
count=subtype_field_count,
value=record[subtype_field_name])
self.write_field(
fd=fd,
field_type=subtype_field_type,
count=subtype_field_count,
value=record[subtype_field_name])
def _read_complex_field(self, *, fd: io.BytesIO, field_type: str, count: Union[int, str])\
-> Union[bytes, List[Dict[str, Any]], Dict[str, Any]]:
def read_field(
self,
*,
fd: io.BytesIO,
field_type: str,
count: Union[int, str]
) -> Union[bytes, List[Dict[str, Any]], Dict[str, Any]]:
assert fd
if field_type not in self.subtypes:
return _read_primitive_field(fd=fd, field_type=field_type, count=count)
if isinstance(count, int):
assert count >= 0, f"{count!r} must be non-neg int"
elif count == "...":
@@ -468,16 +491,10 @@ class LNSerializer:
vars_dict=parsed,
allow_any=True)
if subtype_field_type in self.subtypes:
parsed[subtype_field_name] = self._read_complex_field(
fd=fd,
field_type=subtype_field_type,
count=subtype_field_count)
else:
parsed[subtype_field_name] = _read_field(
fd=fd,
field_type=subtype_field_type,
count=subtype_field_count)
parsed[subtype_field_name] = self.read_field(
fd=fd,
field_type=subtype_field_type,
count=subtype_field_count)
parsedlist.append(parsed)
return parsedlist if count == '...' or count > 1 else parsedlist[0]
@@ -503,18 +520,11 @@ class LNSerializer:
vars_dict=kwargs[tlv_record_name],
allow_any=True)
field_value = kwargs[tlv_record_name][field_name]
if field_type in self.subtypes:
self._write_complex_field(
fd=tlv_record_fd,
field_type=field_type,
count=field_count,
value=field_value)
else:
_write_field(
fd=tlv_record_fd,
field_type=field_type,
count=field_count,
value=field_value)
self.write_field(
fd=tlv_record_fd,
field_type=field_type,
count=field_count,
value=field_value)
else:
raise Exception(f"unexpected row in scheme: {row!r}")
_write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_record_fd.getvalue())
@@ -557,16 +567,10 @@ class LNSerializer:
vars_dict=parsed[tlv_record_name],
allow_any=True)
#print(f">> count={field_count}. parsed={parsed}")
if field_type in self.subtypes:
parsed[tlv_record_name][field_name] = self._read_complex_field(
fd=tlv_record_fd,
field_type=field_type,
count=field_count)
else:
parsed[tlv_record_name][field_name] = _read_field(
fd=tlv_record_fd,
field_type=field_type,
count=field_count)
parsed[tlv_record_name][field_name] = self.read_field(
fd=tlv_record_fd,
field_type=field_type,
count=field_count)
else:
raise Exception(f"unexpected row in scheme: {row!r}")
if _num_remaining_bytes_to_read(tlv_record_fd) > 0:
@@ -603,10 +607,7 @@ class LNSerializer:
except KeyError:
field_value = 0 # default mandatory fields to zero
#print(f">>> encode_msg. writing field: {field_name}. value={field_value!r}. field_type={field_type!r}. count={field_count!r}")
_write_field(fd=fd,
field_type=field_type,
count=field_count,
value=field_value)
_write_primitive_field(fd=fd, field_type=field_type, count=field_count, value=field_value)
#print(f">>> encode_msg. so far: {fd.getvalue().hex()}")
else:
raise Exception(f"unexpected row in scheme: {row!r}")
@@ -651,10 +652,7 @@ class LNSerializer:
parsed[tlv_stream_name] = d
continue
#print(f">> count={field_count}. parsed={parsed}")
parsed[field_name] = _read_field(
fd=fd,
field_type=field_type,
count=field_count)
parsed[field_name] = _read_primitive_field(fd=fd, field_type=field_type, count=field_count)
else:
raise Exception(f"unexpected row in scheme: {row!r}")
except FailedToParseMsg as e:

View File

@@ -193,7 +193,7 @@ def send_onion_message_to(
if len(node_id_or_blinded_path) > 33: # assume blinded path
with io.BytesIO(node_id_or_blinded_path) as blinded_path_fd:
try:
blinded_path = OnionWireSerializer._read_complex_field(
blinded_path = OnionWireSerializer.read_field(
fd=blinded_path_fd,
field_type='blinded_path',
count=1)

View File

@@ -194,7 +194,7 @@ class TestOnionMessage(ElectrumTestCase):
# TODO: serialization test to test_lnmsg.py
with io.BytesIO() as blinded_path_fd:
OnionWireSerializer._write_complex_field(
OnionWireSerializer.write_field(
fd=blinded_path_fd,
field_type='blinded_path',
count=1,