lnmsg: additional de/serialization support for onion messages
- add support for `subtype`/`subtypedata` type declarations - add new primitive type `sciddir_or_pubkey` - better assert message for cardinality errors
This commit is contained in:
@@ -147,6 +147,18 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> U
|
||||
type_len = 33
|
||||
elif field_type == 'short_channel_id':
|
||||
type_len = 8
|
||||
elif field_type == 'sciddir_or_pubkey':
|
||||
buf = fd.read(1)
|
||||
if buf[0] in [0, 1]:
|
||||
type_len = 9
|
||||
elif buf[0] in [2, 3]:
|
||||
type_len = 33
|
||||
else:
|
||||
raise Exception(f"invalid sciddir_or_pubkey, prefix byte not in range 0-3")
|
||||
buf += fd.read(type_len - 1)
|
||||
if len(buf) != type_len:
|
||||
raise UnexpectedEndOfStream()
|
||||
return buf
|
||||
|
||||
if count == "...":
|
||||
total_len = -1 # read all
|
||||
@@ -225,6 +237,14 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str],
|
||||
type_len = 33
|
||||
elif field_type == 'short_channel_id':
|
||||
type_len = 8
|
||||
elif field_type == 'sciddir_or_pubkey':
|
||||
assert isinstance(value, bytes)
|
||||
if value[0] in [0, 1]:
|
||||
type_len = 9 # short_channel_id
|
||||
elif value[0] in [2, 3]:
|
||||
type_len = 33 # point
|
||||
else:
|
||||
raise Exception(f"invalid sciddir_or_pubkey, prefix byte not in range 0-3")
|
||||
total_len = -1
|
||||
if count != "...":
|
||||
if type_len is None:
|
||||
@@ -299,6 +319,8 @@ class LNSerializer:
|
||||
self.in_tlv_stream_get_record_type_from_name = {} # type: Dict[str, Dict[str, int]]
|
||||
self.in_tlv_stream_get_record_name_from_type = {} # type: Dict[str, Dict[int, str]]
|
||||
|
||||
self.subtypes = {} # type: Dict[str, Dict[str, Sequence[str]]]
|
||||
|
||||
if for_onion_wire:
|
||||
path = os.path.join(os.path.dirname(__file__), "lnwire", "onion_wire.csv")
|
||||
else:
|
||||
@@ -348,9 +370,112 @@ class LNSerializer:
|
||||
assert tlv_stream_name == row[1]
|
||||
assert tlv_record_name == row[2]
|
||||
self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type].append(tuple(row))
|
||||
elif row[0] == "subtype":
|
||||
# subtype,<subtypename>
|
||||
subtypename = row[1]
|
||||
assert subtypename not in self.subtypes, f"duplicate declaration of subtype {subtypename}"
|
||||
self.subtypes[subtypename] = {}
|
||||
elif row[0] == "subtypedata":
|
||||
# subtypedata,<subtypename>,<fieldname>,<typename>,[<count>]
|
||||
subtypename = row[1]
|
||||
fieldname = row[2]
|
||||
assert subtypename in self.subtypes, f"subtypedata definition for subtype {subtypename} declared before subtype"
|
||||
assert fieldname not in self.subtypes[subtypename], f"duplicate field definition for {fieldname} for subtype {subtypename}"
|
||||
self.subtypes[subtypename][fieldname] = tuple(row)
|
||||
else:
|
||||
pass # TODO
|
||||
|
||||
def _write_complex_field(self, *, fd: io.BytesIO, field_type: str, count: Union[int, str],
|
||||
value: Union[List[Dict[str, Any]], Dict[str, Any]]) -> None:
|
||||
assert fd
|
||||
assert field_type in self.subtypes, f"unknown subtype {field_type}"
|
||||
|
||||
if isinstance(count, int):
|
||||
assert count >= 0, f"{count!r} must be non-neg int"
|
||||
elif count == "...":
|
||||
pass
|
||||
else:
|
||||
raise Exception(f"unexpected field count: {count!r}")
|
||||
if count == 0:
|
||||
return
|
||||
|
||||
if count == 1:
|
||||
assert isinstance(value, dict) or isinstance(value, list)
|
||||
values = [value] if isinstance(value, dict) else value
|
||||
else:
|
||||
assert isinstance(value, list), f'{field_type=}, expected value of type list for {count=}'
|
||||
values = value
|
||||
|
||||
if count == '...':
|
||||
count = len(values)
|
||||
else:
|
||||
assert count == len(values), f'{field_type=}, expected {count} but got {len(values)}'
|
||||
if count == 0:
|
||||
return
|
||||
|
||||
for record in values:
|
||||
for subtypename, row in self.subtypes[field_type].items():
|
||||
# subtypedata,<subtypename>,<fieldname>,<typename>,[<count>]
|
||||
subtype_field_name = row[2]
|
||||
subtype_field_type = row[3]
|
||||
subtype_field_count_str = row[4]
|
||||
|
||||
subtype_field_count = _resolve_field_count(subtype_field_count_str,
|
||||
vars_dict=record,
|
||||
allow_any=True)
|
||||
|
||||
if subtype_field_name not in record:
|
||||
raise Exception(f'complex field type {field_type} missing element {subtype_field_name}')
|
||||
|
||||
if subtype_field_type in self.subtypes:
|
||||
self._write_complex_field(fd=fd,
|
||||
field_type=subtype_field_type,
|
||||
count=subtype_field_count,
|
||||
value=record[subtype_field_name])
|
||||
else:
|
||||
_write_field(fd=fd,
|
||||
field_type=subtype_field_type,
|
||||
count=subtype_field_count,
|
||||
value=record[subtype_field_name])
|
||||
|
||||
def _read_complex_field(self, *, fd: io.BytesIO, field_type: str, count: Union[int, str])\
|
||||
-> Union[bytes, List[Dict[str, Any]], Dict[str, Any]]:
|
||||
assert fd
|
||||
if isinstance(count, int):
|
||||
assert count >= 0, f"{count!r} must be non-neg int"
|
||||
elif count == "...":
|
||||
pass
|
||||
else:
|
||||
raise Exception(f"unexpected field count: {count!r}")
|
||||
if count == 0:
|
||||
return b""
|
||||
|
||||
parsedlist = []
|
||||
|
||||
while _num_remaining_bytes_to_read(fd):
|
||||
parsed = {}
|
||||
for subtypename, row in self.subtypes[field_type].items():
|
||||
# subtypedata,<subtypename>,<fieldname>,<typename>,[<count>]
|
||||
subtype_field_name = row[2]
|
||||
subtype_field_type = row[3]
|
||||
subtype_field_count_str = row[4]
|
||||
|
||||
subtype_field_count = _resolve_field_count(subtype_field_count_str,
|
||||
vars_dict=parsed,
|
||||
allow_any=True)
|
||||
|
||||
if subtype_field_type in self.subtypes:
|
||||
parsed[subtype_field_name] = self._read_complex_field(fd=fd,
|
||||
field_type=subtype_field_type,
|
||||
count=subtype_field_count)
|
||||
else:
|
||||
parsed[subtype_field_name] = _read_field(fd=fd,
|
||||
field_type=subtype_field_type,
|
||||
count=subtype_field_count)
|
||||
parsedlist.append(parsed)
|
||||
|
||||
return parsedlist if count == '...' or count > 1 else parsedlist[0]
|
||||
|
||||
def write_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str, **kwargs) -> None:
|
||||
scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
|
||||
for tlv_record_type, scheme in scheme_map.items(): # note: tlv_record_type is monotonically increasing
|
||||
@@ -372,10 +497,16 @@ class LNSerializer:
|
||||
vars_dict=kwargs[tlv_record_name],
|
||||
allow_any=True)
|
||||
field_value = kwargs[tlv_record_name][field_name]
|
||||
_write_field(fd=tlv_record_fd,
|
||||
field_type=field_type,
|
||||
count=field_count,
|
||||
value=field_value)
|
||||
if field_type in self.subtypes:
|
||||
self._write_complex_field(fd=tlv_record_fd,
|
||||
field_type=field_type,
|
||||
count=field_count,
|
||||
value=field_value)
|
||||
else:
|
||||
_write_field(fd=tlv_record_fd,
|
||||
field_type=field_type,
|
||||
count=field_count,
|
||||
value=field_value)
|
||||
else:
|
||||
raise Exception(f"unexpected row in scheme: {row!r}")
|
||||
_write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_record_fd.getvalue())
|
||||
@@ -417,9 +548,14 @@ class LNSerializer:
|
||||
vars_dict=parsed[tlv_record_name],
|
||||
allow_any=True)
|
||||
#print(f">> count={field_count}. parsed={parsed}")
|
||||
parsed[tlv_record_name][field_name] = _read_field(fd=tlv_record_fd,
|
||||
field_type=field_type,
|
||||
count=field_count)
|
||||
if field_type in self.subtypes:
|
||||
parsed[tlv_record_name][field_name] = self._read_complex_field(fd=tlv_record_fd,
|
||||
field_type=field_type,
|
||||
count=field_count)
|
||||
else:
|
||||
parsed[tlv_record_name][field_name] = _read_field(fd=tlv_record_fd,
|
||||
field_type=field_type,
|
||||
count=field_count)
|
||||
else:
|
||||
raise Exception(f"unexpected row in scheme: {row!r}")
|
||||
if _num_remaining_bytes_to_read(tlv_record_fd) > 0:
|
||||
|
||||
Reference in New Issue
Block a user