lnmsg: support both primitive and complex types (subtypes) in LNSerializer.
This renames lnmsg._{read,write}_field to lnmsg._{read,write}_primitive_field, renames
LNSerializer._{read,write}_complex_type to LNSerializer.{read,write}_field and allows
LNSerializer.{read,write}_field to handle both primitive and complex types.
Also makes these funcs public, as these encodings are used outside of lnmsg as well
(e.g. encoding blinded paths in BOLT12 invoice_request)
This commit is contained in:
@@ -1516,7 +1516,7 @@ class Commands(Logger):
|
||||
blinded_path = create_blinded_path(session_key, path=path, final_recipient_data={}, dummy_hops=dummy_hops)
|
||||
|
||||
with io.BytesIO() as blinded_path_fd:
|
||||
OnionWireSerializer._write_complex_field(
|
||||
OnionWireSerializer.write_field(
|
||||
fd=blinded_path_fd,
|
||||
field_type='blinded_path',
|
||||
count=1,
|
||||
|
||||
@@ -88,8 +88,14 @@ def read_bigsize_int(fd: io.BytesIO) -> Optional[int]:
|
||||
|
||||
# TODO: maybe if field_type is not "byte", we could return a list of type_len sized chunks?
|
||||
# if field_type is a numeric, we could return a list of ints?
|
||||
def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> Union[bytes, int]:
|
||||
if not fd: raise Exception()
|
||||
def _read_primitive_field(
|
||||
*,
|
||||
fd: io.BytesIO,
|
||||
field_type: str,
|
||||
count: Union[int, str]
|
||||
) -> Union[bytes, int]:
|
||||
if not fd:
|
||||
raise Exception()
|
||||
if isinstance(count, int):
|
||||
assert count >= 0, f"{count!r} must be non-neg int"
|
||||
elif count == "...":
|
||||
@@ -174,9 +180,15 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> U
|
||||
|
||||
|
||||
# TODO: maybe for "value" we could accept a list with len "count" of appropriate items
|
||||
def _write_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str],
|
||||
value: Union[bytes, int]) -> None:
|
||||
if not fd: raise Exception()
|
||||
def _write_primitive_field(
|
||||
*,
|
||||
fd: io.BytesIO,
|
||||
field_type: str,
|
||||
count: Union[int, str],
|
||||
value: Union[bytes, int]
|
||||
) -> None:
|
||||
if not fd:
|
||||
raise Exception()
|
||||
if isinstance(count, int):
|
||||
assert count >= 0, f"{count!r} must be non-neg int"
|
||||
elif count == "...":
|
||||
@@ -263,18 +275,18 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str],
|
||||
|
||||
def _read_tlv_record(*, fd: io.BytesIO) -> Tuple[int, bytes]:
|
||||
if not fd: raise Exception()
|
||||
tlv_type = _read_field(fd=fd, field_type="bigsize", count=1)
|
||||
tlv_len = _read_field(fd=fd, field_type="bigsize", count=1)
|
||||
tlv_val = _read_field(fd=fd, field_type="byte", count=tlv_len)
|
||||
tlv_type = _read_primitive_field(fd=fd, field_type="bigsize", count=1)
|
||||
tlv_len = _read_primitive_field(fd=fd, field_type="bigsize", count=1)
|
||||
tlv_val = _read_primitive_field(fd=fd, field_type="byte", count=tlv_len)
|
||||
return tlv_type, tlv_val
|
||||
|
||||
|
||||
def _write_tlv_record(*, fd: io.BytesIO, tlv_type: int, tlv_val: bytes) -> None:
|
||||
if not fd: raise Exception()
|
||||
tlv_len = len(tlv_val)
|
||||
_write_field(fd=fd, field_type="bigsize", count=1, value=tlv_type)
|
||||
_write_field(fd=fd, field_type="bigsize", count=1, value=tlv_len)
|
||||
_write_field(fd=fd, field_type="byte", count=tlv_len, value=tlv_val)
|
||||
_write_primitive_field(fd=fd, field_type="bigsize", count=1, value=tlv_type)
|
||||
_write_primitive_field(fd=fd, field_type="bigsize", count=1, value=tlv_len)
|
||||
_write_primitive_field(fd=fd, field_type="byte", count=tlv_len, value=tlv_val)
|
||||
|
||||
|
||||
def _resolve_field_count(field_count_str: str, *, vars_dict: dict, allow_any=False) -> Union[int, str]:
|
||||
@@ -385,10 +397,19 @@ class LNSerializer:
|
||||
else:
|
||||
pass # TODO
|
||||
|
||||
def _write_complex_field(self, *, fd: io.BytesIO, field_type: str, count: Union[int, str],
|
||||
value: Union[List[Dict[str, Any]], Dict[str, Any]]) -> None:
|
||||
def write_field(
|
||||
self,
|
||||
*,
|
||||
fd: io.BytesIO,
|
||||
field_type: str,
|
||||
count: Union[int, str],
|
||||
value: Union[List[Dict[str, Any]], Dict[str, Any]]
|
||||
) -> None:
|
||||
assert fd
|
||||
assert field_type in self.subtypes, f"unknown subtype {field_type}"
|
||||
|
||||
if field_type not in self.subtypes:
|
||||
_write_primitive_field(fd=fd, field_type=field_type, count=count, value=value)
|
||||
return
|
||||
|
||||
if isinstance(count, int):
|
||||
assert count >= 0, f"{count!r} must be non-neg int"
|
||||
@@ -428,22 +449,24 @@ class LNSerializer:
|
||||
if subtype_field_name not in record:
|
||||
raise Exception(f'complex field type {field_type} missing element {subtype_field_name}')
|
||||
|
||||
if subtype_field_type in self.subtypes:
|
||||
self._write_complex_field(
|
||||
fd=fd,
|
||||
field_type=subtype_field_type,
|
||||
count=subtype_field_count,
|
||||
value=record[subtype_field_name])
|
||||
else:
|
||||
_write_field(
|
||||
fd=fd,
|
||||
field_type=subtype_field_type,
|
||||
count=subtype_field_count,
|
||||
value=record[subtype_field_name])
|
||||
self.write_field(
|
||||
fd=fd,
|
||||
field_type=subtype_field_type,
|
||||
count=subtype_field_count,
|
||||
value=record[subtype_field_name])
|
||||
|
||||
def _read_complex_field(self, *, fd: io.BytesIO, field_type: str, count: Union[int, str])\
|
||||
-> Union[bytes, List[Dict[str, Any]], Dict[str, Any]]:
|
||||
def read_field(
|
||||
self,
|
||||
*,
|
||||
fd: io.BytesIO,
|
||||
field_type: str,
|
||||
count: Union[int, str]
|
||||
) -> Union[bytes, List[Dict[str, Any]], Dict[str, Any]]:
|
||||
assert fd
|
||||
|
||||
if field_type not in self.subtypes:
|
||||
return _read_primitive_field(fd=fd, field_type=field_type, count=count)
|
||||
|
||||
if isinstance(count, int):
|
||||
assert count >= 0, f"{count!r} must be non-neg int"
|
||||
elif count == "...":
|
||||
@@ -468,16 +491,10 @@ class LNSerializer:
|
||||
vars_dict=parsed,
|
||||
allow_any=True)
|
||||
|
||||
if subtype_field_type in self.subtypes:
|
||||
parsed[subtype_field_name] = self._read_complex_field(
|
||||
fd=fd,
|
||||
field_type=subtype_field_type,
|
||||
count=subtype_field_count)
|
||||
else:
|
||||
parsed[subtype_field_name] = _read_field(
|
||||
fd=fd,
|
||||
field_type=subtype_field_type,
|
||||
count=subtype_field_count)
|
||||
parsed[subtype_field_name] = self.read_field(
|
||||
fd=fd,
|
||||
field_type=subtype_field_type,
|
||||
count=subtype_field_count)
|
||||
parsedlist.append(parsed)
|
||||
|
||||
return parsedlist if count == '...' or count > 1 else parsedlist[0]
|
||||
@@ -503,18 +520,11 @@ class LNSerializer:
|
||||
vars_dict=kwargs[tlv_record_name],
|
||||
allow_any=True)
|
||||
field_value = kwargs[tlv_record_name][field_name]
|
||||
if field_type in self.subtypes:
|
||||
self._write_complex_field(
|
||||
fd=tlv_record_fd,
|
||||
field_type=field_type,
|
||||
count=field_count,
|
||||
value=field_value)
|
||||
else:
|
||||
_write_field(
|
||||
fd=tlv_record_fd,
|
||||
field_type=field_type,
|
||||
count=field_count,
|
||||
value=field_value)
|
||||
self.write_field(
|
||||
fd=tlv_record_fd,
|
||||
field_type=field_type,
|
||||
count=field_count,
|
||||
value=field_value)
|
||||
else:
|
||||
raise Exception(f"unexpected row in scheme: {row!r}")
|
||||
_write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_record_fd.getvalue())
|
||||
@@ -557,16 +567,10 @@ class LNSerializer:
|
||||
vars_dict=parsed[tlv_record_name],
|
||||
allow_any=True)
|
||||
#print(f">> count={field_count}. parsed={parsed}")
|
||||
if field_type in self.subtypes:
|
||||
parsed[tlv_record_name][field_name] = self._read_complex_field(
|
||||
fd=tlv_record_fd,
|
||||
field_type=field_type,
|
||||
count=field_count)
|
||||
else:
|
||||
parsed[tlv_record_name][field_name] = _read_field(
|
||||
fd=tlv_record_fd,
|
||||
field_type=field_type,
|
||||
count=field_count)
|
||||
parsed[tlv_record_name][field_name] = self.read_field(
|
||||
fd=tlv_record_fd,
|
||||
field_type=field_type,
|
||||
count=field_count)
|
||||
else:
|
||||
raise Exception(f"unexpected row in scheme: {row!r}")
|
||||
if _num_remaining_bytes_to_read(tlv_record_fd) > 0:
|
||||
@@ -603,10 +607,7 @@ class LNSerializer:
|
||||
except KeyError:
|
||||
field_value = 0 # default mandatory fields to zero
|
||||
#print(f">>> encode_msg. writing field: {field_name}. value={field_value!r}. field_type={field_type!r}. count={field_count!r}")
|
||||
_write_field(fd=fd,
|
||||
field_type=field_type,
|
||||
count=field_count,
|
||||
value=field_value)
|
||||
_write_primitive_field(fd=fd, field_type=field_type, count=field_count, value=field_value)
|
||||
#print(f">>> encode_msg. so far: {fd.getvalue().hex()}")
|
||||
else:
|
||||
raise Exception(f"unexpected row in scheme: {row!r}")
|
||||
@@ -651,10 +652,7 @@ class LNSerializer:
|
||||
parsed[tlv_stream_name] = d
|
||||
continue
|
||||
#print(f">> count={field_count}. parsed={parsed}")
|
||||
parsed[field_name] = _read_field(
|
||||
fd=fd,
|
||||
field_type=field_type,
|
||||
count=field_count)
|
||||
parsed[field_name] = _read_primitive_field(fd=fd, field_type=field_type, count=field_count)
|
||||
else:
|
||||
raise Exception(f"unexpected row in scheme: {row!r}")
|
||||
except FailedToParseMsg as e:
|
||||
|
||||
@@ -193,7 +193,7 @@ def send_onion_message_to(
|
||||
if len(node_id_or_blinded_path) > 33: # assume blinded path
|
||||
with io.BytesIO(node_id_or_blinded_path) as blinded_path_fd:
|
||||
try:
|
||||
blinded_path = OnionWireSerializer._read_complex_field(
|
||||
blinded_path = OnionWireSerializer.read_field(
|
||||
fd=blinded_path_fd,
|
||||
field_type='blinded_path',
|
||||
count=1)
|
||||
|
||||
@@ -194,7 +194,7 @@ class TestOnionMessage(ElectrumTestCase):
|
||||
|
||||
# TODO: serialization test to test_lnmsg.py
|
||||
with io.BytesIO() as blinded_path_fd:
|
||||
OnionWireSerializer._write_complex_field(
|
||||
OnionWireSerializer.write_field(
|
||||
fd=blinded_path_fd,
|
||||
field_type='blinded_path',
|
||||
count=1,
|
||||
|
||||
Reference in New Issue
Block a user