lnmsg: additional de/serialization support for onion messages

- add support for `subtype`/`subtypedata` type declarations
- add new primitive type `sciddir_or_pubkey`
- better assert message for cardinality errors
@@ -147,6 +147,18 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> U
         type_len = 33
     elif field_type == 'short_channel_id':
         type_len = 8
+    elif field_type == 'sciddir_or_pubkey':
+        buf = fd.read(1)
+        if buf[0] in [0, 1]:
+            type_len = 9
+        elif buf[0] in [2, 3]:
+            type_len = 33
+        else:
+            raise Exception(f"invalid sciddir_or_pubkey, prefix byte not in range 0-3")
+        buf += fd.read(type_len - 1)
+        if len(buf) != type_len:
+            raise UnexpectedEndOfStream()
+        return buf
 
     if count == "...":
         total_len = -1  # read all
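
The variable length is determined entirely by the first byte, which doubles as the compressed-point parity byte in the pubkey case: 0 or 1 selects the 9-byte direction-plus-short_channel_id form, 2 or 3 a 33-byte compressed point. A minimal standalone sketch of the same length dispatch (the helper name `sciddir_or_pubkey_len` is invented for illustration):

    def sciddir_or_pubkey_len(prefix: int) -> int:
        # invented helper; mirrors the length dispatch in _read_field above
        if prefix in (0, 1):
            # direction byte + 8-byte short_channel_id -> 9 bytes total
            return 9
        if prefix in (2, 3):
            # compressed secp256k1 point: parity byte + 32-byte x coordinate -> 33 bytes total
            return 33
        raise ValueError("prefix byte not in range 0-3")

    assert sciddir_or_pubkey_len(0x00) == 9
    assert sciddir_or_pubkey_len(0x03) == 33
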
@@ -225,6 +237,14 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str],
         type_len = 33
     elif field_type == 'short_channel_id':
         type_len = 8
+    elif field_type == 'sciddir_or_pubkey':
+        assert isinstance(value, bytes)
+        if value[0] in [0, 1]:
+            type_len = 9  # short_channel_id
+        elif value[0] in [2, 3]:
+            type_len = 33  # point
+        else:
+            raise Exception(f"invalid sciddir_or_pubkey, prefix byte not in range 0-3")
     total_len = -1
     if count != "...":
         if type_len is None:
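
On the write path the value must already be serialized bytes with the prefix in place; the new branch only derives the expected length from that prefix. A hedged example of well-formed values (both byte strings are invented placeholders):

    # 9 bytes: prefix 0x00 or 0x01 (direction) followed by an 8-byte short_channel_id
    scid_dir = bytes([0x00]) + (123456789).to_bytes(8, byteorder="big")  # invented value
    assert scid_dir[0] in [0, 1] and len(scid_dir) == 9

    # 33 bytes: prefix 0x02 or 0x03 (point parity) followed by the 32-byte x coordinate
    pubkey = bytes([0x02]) + bytes(32)  # invented value, not a real curve point
    assert pubkey[0] in [2, 3] and len(pubkey) == 33
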
@@ -299,6 +319,8 @@ class LNSerializer:
         self.in_tlv_stream_get_record_type_from_name = {}  # type: Dict[str, Dict[str, int]]
         self.in_tlv_stream_get_record_name_from_type = {}  # type: Dict[str, Dict[int, str]]
 
+        self.subtypes = {}  # type: Dict[str, Dict[str, Sequence[str]]]
+
         if for_onion_wire:
             path = os.path.join(os.path.dirname(__file__), "lnwire", "onion_wire.csv")
         else:
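
The new dict maps a subtype name to its ordered field rows, populated from `subtype`/`subtypedata` lines in the CSV (parsed in the next hunk). A self-contained sketch of how such rows look and parse; the `onionmsg_hop` rows are written in the style of BOLT 4, not quoted verbatim from the bundled onion_wire.csv:

    import csv, io, textwrap

    # illustrative rows, not the actual contents of onion_wire.csv
    rows = textwrap.dedent("""\
        subtype,onionmsg_hop
        subtypedata,onionmsg_hop,blinded_node_id,point,
        subtypedata,onionmsg_hop,enclen,u16,
        subtypedata,onionmsg_hop,encrypted_recipient_data,byte,enclen
    """)

    subtypes = {}
    for row in csv.reader(io.StringIO(rows)):
        if not row:
            continue
        if row[0] == "subtype":
            subtypes[row[1]] = {}
        elif row[0] == "subtypedata":
            subtypes[row[1]][row[2]] = tuple(row)

    # field order is preserved, which matters for de/serialization
    assert list(subtypes["onionmsg_hop"]) == [
        "blinded_node_id", "enclen", "encrypted_recipient_data"]
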
@@ -348,9 +370,112 @@ class LNSerializer:
                     assert tlv_stream_name == row[1]
                     assert tlv_record_name == row[2]
                     self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type].append(tuple(row))
+                elif row[0] == "subtype":
+                    # subtype,<subtypename>
+                    subtypename = row[1]
+                    assert subtypename not in self.subtypes, f"duplicate declaration of subtype {subtypename}"
+                    self.subtypes[subtypename] = {}
+                elif row[0] == "subtypedata":
+                    # subtypedata,<subtypename>,<fieldname>,<typename>,[<count>]
+                    subtypename = row[1]
+                    fieldname = row[2]
+                    assert subtypename in self.subtypes, f"subtypedata definition for subtype {subtypename} declared before subtype"
+                    assert fieldname not in self.subtypes[subtypename], f"duplicate field definition for {fieldname} for subtype {subtypename}"
+                    self.subtypes[subtypename][fieldname] = tuple(row)
                 else:
                     pass  # TODO
 
+    def _write_complex_field(self, *, fd: io.BytesIO, field_type: str, count: Union[int, str],
+                             value: Union[List[Dict[str, Any]], Dict[str, Any]]) -> None:
+        assert fd
+        assert field_type in self.subtypes, f"unknown subtype {field_type}"
+
+        if isinstance(count, int):
+            assert count >= 0, f"{count!r} must be non-neg int"
+        elif count == "...":
+            pass
+        else:
+            raise Exception(f"unexpected field count: {count!r}")
+        if count == 0:
+            return
+
+        if count == 1:
+            assert isinstance(value, dict) or isinstance(value, list)
+            values = [value] if isinstance(value, dict) else value
+        else:
+            assert isinstance(value, list), f'{field_type=}, expected value of type list for {count=}'
+            values = value
+
+        if count == '...':
+            count = len(values)
+        else:
+            assert count == len(values), f'{field_type=}, expected {count} but got {len(values)}'
+        if count == 0:
+            return
+
+        for record in values:
+            for subtypename, row in self.subtypes[field_type].items():
+                # subtypedata,<subtypename>,<fieldname>,<typename>,[<count>]
+                subtype_field_name = row[2]
+                subtype_field_type = row[3]
+                subtype_field_count_str = row[4]
+
+                subtype_field_count = _resolve_field_count(subtype_field_count_str,
+                                                           vars_dict=record,
+                                                           allow_any=True)
+
+                if subtype_field_name not in record:
+                    raise Exception(f'complex field type {field_type} missing element {subtype_field_name}')
+
+                if subtype_field_type in self.subtypes:
+                    self._write_complex_field(fd=fd,
+                                              field_type=subtype_field_type,
+                                              count=subtype_field_count,
+                                              value=record[subtype_field_name])
+                else:
+                    _write_field(fd=fd,
+                                 field_type=subtype_field_type,
+                                 count=subtype_field_count,
+                                 value=record[subtype_field_name])
+
+    def _read_complex_field(self, *, fd: io.BytesIO, field_type: str, count: Union[int, str])\
+            -> Union[bytes, List[Dict[str, Any]], Dict[str, Any]]:
+        assert fd
+        if isinstance(count, int):
+            assert count >= 0, f"{count!r} must be non-neg int"
+        elif count == "...":
+            pass
+        else:
+            raise Exception(f"unexpected field count: {count!r}")
+        if count == 0:
+            return b""
+
+        parsedlist = []
+
+        while _num_remaining_bytes_to_read(fd):
+            parsed = {}
+            for subtypename, row in self.subtypes[field_type].items():
+                # subtypedata,<subtypename>,<fieldname>,<typename>,[<count>]
+                subtype_field_name = row[2]
+                subtype_field_type = row[3]
+                subtype_field_count_str = row[4]
+
+                subtype_field_count = _resolve_field_count(subtype_field_count_str,
+                                                           vars_dict=parsed,
+                                                           allow_any=True)
+
+                if subtype_field_type in self.subtypes:
+                    parsed[subtype_field_name] = self._read_complex_field(fd=fd,
+                                                                          field_type=subtype_field_type,
+                                                                          count=subtype_field_count)
+                else:
+                    parsed[subtype_field_name] = _read_field(fd=fd,
+                                                             field_type=subtype_field_type,
+                                                             count=subtype_field_count)
+            parsedlist.append(parsed)
+
+        return parsedlist if count == '...' or count > 1 else parsedlist[0]
+
     def write_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str, **kwargs) -> None:
         scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
         for tlv_record_type, scheme in scheme_map.items():  # note: tlv_record_type is monotonically increasing
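
Together these two helpers map a subtype to and from plain Python structures: one record becomes a dict keyed by field name, and a repeated field becomes a list of such dicts. A hedged round-trip sketch, assuming the module-level `OnionWireSerializer` instance exists and the bundled CSV declares a BOLT-4-style `onionmsg_hop` subtype with the fields shown earlier (all values invented; adjust names if the real CSV differs):

    import io
    from electrum.lnmsg import OnionWireSerializer  # assumed import path

    hop = {
        "blinded_node_id": bytes([0x02]) + bytes(32),  # invented 33-byte point
        "enclen": 35,
        "encrypted_recipient_data": bytes(35),
    }
    fd = io.BytesIO()
    OnionWireSerializer._write_complex_field(fd=fd, field_type="onionmsg_hop",
                                             count=1, value=hop)
    fd.seek(0)
    # count == 1, so a single dict comes back rather than a list
    assert OnionWireSerializer._read_complex_field(fd=fd, field_type="onionmsg_hop",
                                                   count=1) == hop
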
@@ -372,10 +497,16 @@ class LNSerializer:
                                                            vars_dict=kwargs[tlv_record_name],
                                                            allow_any=True)
                     field_value = kwargs[tlv_record_name][field_name]
-                    _write_field(fd=tlv_record_fd,
-                                 field_type=field_type,
-                                 count=field_count,
-                                 value=field_value)
+                    if field_type in self.subtypes:
+                        self._write_complex_field(fd=tlv_record_fd,
+                                                  field_type=field_type,
+                                                  count=field_count,
+                                                  value=field_value)
+                    else:
+                        _write_field(fd=tlv_record_fd,
+                                     field_type=field_type,
+                                     count=field_count,
+                                     value=field_value)
                 else:
                     raise Exception(f"unexpected row in scheme: {row!r}")
             _write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_record_fd.getvalue())
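
From the caller's side, `write_tlv_stream` now accepts a nested dict for a subtype-typed field (the count-1 case), where previously only scalars and bytes appeared. An illustrative input shape (record and field names invented, not taken from the real CSV):

    kwargs = {
        "example_record": {           # invented record name
            "first_hop": {            # subtype-typed field, count 1: a single dict
                "blinded_node_id": bytes([0x02]) + bytes(32),
                "enclen": 3,
                "encrypted_recipient_data": b"abc",
            },
        },
    }
    # hypothetical call: serializer.write_tlv_stream(fd=fd, tlv_stream_name="...", **kwargs)
    assert isinstance(kwargs["example_record"]["first_hop"], dict)
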
@@ -417,9 +548,14 @@ class LNSerializer:
                                                            vars_dict=parsed[tlv_record_name],
                                                            allow_any=True)
                     #print(f">> count={field_count}. parsed={parsed}")
-                    parsed[tlv_record_name][field_name] = _read_field(fd=tlv_record_fd,
-                                                                      field_type=field_type,
-                                                                      count=field_count)
+                    if field_type in self.subtypes:
+                        parsed[tlv_record_name][field_name] = self._read_complex_field(fd=tlv_record_fd,
+                                                                                       field_type=field_type,
+                                                                                       count=field_count)
+                    else:
+                        parsed[tlv_record_name][field_name] = _read_field(fd=tlv_record_fd,
+                                                                          field_type=field_type,
+                                                                          count=field_count)
                 else:
                     raise Exception(f"unexpected row in scheme: {row!r}")
             if _num_remaining_bytes_to_read(tlv_record_fd) > 0:
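
Symmetrically, `read_tlv_stream` returns the same nested structures: a repeated subtype field (count `...` or greater than 1) comes back as a list of dicts, while scalar fields keep their plain int/bytes representation. An illustrative (invented) parsed result:

    parsed = {
        "example_record": {           # invented record name
            "num_hops": 2,            # scalar field: plain int
            "hops": [                 # repeated subtype field: list of dicts
                {"blinded_node_id": bytes([0x02]) + bytes(32), "enclen": 3,
                 "encrypted_recipient_data": b"abc"},
                {"blinded_node_id": bytes([0x03]) + bytes(32), "enclen": 2,
                 "encrypted_recipient_data": b"hi"},
            ],
        },
    }
    assert all(isinstance(hop, dict) for hop in parsed["example_record"]["hops"])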