1
0

lnmsg: additional de/serialization support for onion messages

- add support for `subtype`/`subtypedata` type declarations
- add new primitive type `sciddir_or_pubkey`
- better assert message for cardinality errors
This commit is contained in:
Sander van Grieken
2024-05-28 10:43:46 +02:00
parent 00bba471ff
commit 7c8dfdecbb

View File

@@ -147,6 +147,18 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> U
type_len = 33
elif field_type == 'short_channel_id':
type_len = 8
elif field_type == 'sciddir_or_pubkey':
buf = fd.read(1)
if buf[0] in [0, 1]:
type_len = 9
elif buf[0] in [2, 3]:
type_len = 33
else:
raise Exception(f"invalid sciddir_or_pubkey, prefix byte not in range 0-3")
buf += fd.read(type_len - 1)
if len(buf) != type_len:
raise UnexpectedEndOfStream()
return buf
if count == "...":
total_len = -1 # read all
@@ -225,6 +237,14 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str],
type_len = 33
elif field_type == 'short_channel_id':
type_len = 8
elif field_type == 'sciddir_or_pubkey':
assert isinstance(value, bytes)
if value[0] in [0, 1]:
type_len = 9 # short_channel_id
elif value[0] in [2, 3]:
type_len = 33 # point
else:
raise Exception(f"invalid sciddir_or_pubkey, prefix byte not in range 0-3")
total_len = -1
if count != "...":
if type_len is None:
@@ -299,6 +319,8 @@ class LNSerializer:
self.in_tlv_stream_get_record_type_from_name = {} # type: Dict[str, Dict[str, int]]
self.in_tlv_stream_get_record_name_from_type = {} # type: Dict[str, Dict[int, str]]
self.subtypes = {} # type: Dict[str, Dict[str, Sequence[str]]]
if for_onion_wire:
path = os.path.join(os.path.dirname(__file__), "lnwire", "onion_wire.csv")
else:
@@ -348,9 +370,112 @@ class LNSerializer:
assert tlv_stream_name == row[1]
assert tlv_record_name == row[2]
self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type].append(tuple(row))
elif row[0] == "subtype":
# subtype,<subtypename>
subtypename = row[1]
assert subtypename not in self.subtypes, f"duplicate declaration of subtype {subtypename}"
self.subtypes[subtypename] = {}
elif row[0] == "subtypedata":
# subtypedata,<subtypename>,<fieldname>,<typename>,[<count>]
subtypename = row[1]
fieldname = row[2]
assert subtypename in self.subtypes, f"subtypedata definition for subtype {subtypename} declared before subtype"
assert fieldname not in self.subtypes[subtypename], f"duplicate field definition for {fieldname} for subtype {subtypename}"
self.subtypes[subtypename][fieldname] = tuple(row)
else:
pass # TODO
def _write_complex_field(self, *, fd: io.BytesIO, field_type: str, count: Union[int, str],
value: Union[List[Dict[str, Any]], Dict[str, Any]]) -> None:
assert fd
assert field_type in self.subtypes, f"unknown subtype {field_type}"
if isinstance(count, int):
assert count >= 0, f"{count!r} must be non-neg int"
elif count == "...":
pass
else:
raise Exception(f"unexpected field count: {count!r}")
if count == 0:
return
if count == 1:
assert isinstance(value, dict) or isinstance(value, list)
values = [value] if isinstance(value, dict) else value
else:
assert isinstance(value, list), f'{field_type=}, expected value of type list for {count=}'
values = value
if count == '...':
count = len(values)
else:
assert count == len(values), f'{field_type=}, expected {count} but got {len(values)}'
if count == 0:
return
for record in values:
for subtypename, row in self.subtypes[field_type].items():
# subtypedata,<subtypename>,<fieldname>,<typename>,[<count>]
subtype_field_name = row[2]
subtype_field_type = row[3]
subtype_field_count_str = row[4]
subtype_field_count = _resolve_field_count(subtype_field_count_str,
vars_dict=record,
allow_any=True)
if subtype_field_name not in record:
raise Exception(f'complex field type {field_type} missing element {subtype_field_name}')
if subtype_field_type in self.subtypes:
self._write_complex_field(fd=fd,
field_type=subtype_field_type,
count=subtype_field_count,
value=record[subtype_field_name])
else:
_write_field(fd=fd,
field_type=subtype_field_type,
count=subtype_field_count,
value=record[subtype_field_name])
def _read_complex_field(self, *, fd: io.BytesIO, field_type: str, count: Union[int, str])\
-> Union[bytes, List[Dict[str, Any]], Dict[str, Any]]:
assert fd
if isinstance(count, int):
assert count >= 0, f"{count!r} must be non-neg int"
elif count == "...":
pass
else:
raise Exception(f"unexpected field count: {count!r}")
if count == 0:
return b""
parsedlist = []
while _num_remaining_bytes_to_read(fd):
parsed = {}
for subtypename, row in self.subtypes[field_type].items():
# subtypedata,<subtypename>,<fieldname>,<typename>,[<count>]
subtype_field_name = row[2]
subtype_field_type = row[3]
subtype_field_count_str = row[4]
subtype_field_count = _resolve_field_count(subtype_field_count_str,
vars_dict=parsed,
allow_any=True)
if subtype_field_type in self.subtypes:
parsed[subtype_field_name] = self._read_complex_field(fd=fd,
field_type=subtype_field_type,
count=subtype_field_count)
else:
parsed[subtype_field_name] = _read_field(fd=fd,
field_type=subtype_field_type,
count=subtype_field_count)
parsedlist.append(parsed)
return parsedlist if count == '...' or count > 1 else parsedlist[0]
def write_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str, **kwargs) -> None:
scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
for tlv_record_type, scheme in scheme_map.items(): # note: tlv_record_type is monotonically increasing
@@ -372,10 +497,16 @@ class LNSerializer:
vars_dict=kwargs[tlv_record_name],
allow_any=True)
field_value = kwargs[tlv_record_name][field_name]
_write_field(fd=tlv_record_fd,
field_type=field_type,
count=field_count,
value=field_value)
if field_type in self.subtypes:
self._write_complex_field(fd=tlv_record_fd,
field_type=field_type,
count=field_count,
value=field_value)
else:
_write_field(fd=tlv_record_fd,
field_type=field_type,
count=field_count,
value=field_value)
else:
raise Exception(f"unexpected row in scheme: {row!r}")
_write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_record_fd.getvalue())
@@ -417,9 +548,14 @@ class LNSerializer:
vars_dict=parsed[tlv_record_name],
allow_any=True)
#print(f">> count={field_count}. parsed={parsed}")
parsed[tlv_record_name][field_name] = _read_field(fd=tlv_record_fd,
field_type=field_type,
count=field_count)
if field_type in self.subtypes:
parsed[tlv_record_name][field_name] = self._read_complex_field(fd=tlv_record_fd,
field_type=field_type,
count=field_count)
else:
parsed[tlv_record_name][field_name] = _read_field(fd=tlv_record_fd,
field_type=field_type,
count=field_count)
else:
raise Exception(f"unexpected row in scheme: {row!r}")
if _num_remaining_bytes_to_read(tlv_record_fd) > 0: