diff --git a/electrum/commands.py b/electrum/commands.py index 6e51fdcdd..e1ccf3a29 100644 --- a/electrum/commands.py +++ b/electrum/commands.py @@ -1516,7 +1516,7 @@ class Commands(Logger): blinded_path = create_blinded_path(session_key, path=path, final_recipient_data={}, dummy_hops=dummy_hops) with io.BytesIO() as blinded_path_fd: - OnionWireSerializer._write_complex_field( + OnionWireSerializer.write_field( fd=blinded_path_fd, field_type='blinded_path', count=1, diff --git a/electrum/lnmsg.py b/electrum/lnmsg.py index da849f22f..eb0b761e2 100644 --- a/electrum/lnmsg.py +++ b/electrum/lnmsg.py @@ -88,8 +88,14 @@ def read_bigsize_int(fd: io.BytesIO) -> Optional[int]: # TODO: maybe if field_type is not "byte", we could return a list of type_len sized chunks? # if field_type is a numeric, we could return a list of ints? -def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> Union[bytes, int]: - if not fd: raise Exception() +def _read_primitive_field( + *, + fd: io.BytesIO, + field_type: str, + count: Union[int, str] +) -> Union[bytes, int]: + if not fd: + raise Exception() if isinstance(count, int): assert count >= 0, f"{count!r} must be non-neg int" elif count == "...": @@ -174,9 +180,15 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> U # TODO: maybe for "value" we could accept a list with len "count" of appropriate items -def _write_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str], - value: Union[bytes, int]) -> None: - if not fd: raise Exception() +def _write_primitive_field( + *, + fd: io.BytesIO, + field_type: str, + count: Union[int, str], + value: Union[bytes, int] +) -> None: + if not fd: + raise Exception() if isinstance(count, int): assert count >= 0, f"{count!r} must be non-neg int" elif count == "...": @@ -263,18 +275,18 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str], def _read_tlv_record(*, fd: io.BytesIO) -> Tuple[int, bytes]: if not fd: raise Exception() - tlv_type = _read_field(fd=fd, field_type="bigsize", count=1) - tlv_len = _read_field(fd=fd, field_type="bigsize", count=1) - tlv_val = _read_field(fd=fd, field_type="byte", count=tlv_len) + tlv_type = _read_primitive_field(fd=fd, field_type="bigsize", count=1) + tlv_len = _read_primitive_field(fd=fd, field_type="bigsize", count=1) + tlv_val = _read_primitive_field(fd=fd, field_type="byte", count=tlv_len) return tlv_type, tlv_val def _write_tlv_record(*, fd: io.BytesIO, tlv_type: int, tlv_val: bytes) -> None: if not fd: raise Exception() tlv_len = len(tlv_val) - _write_field(fd=fd, field_type="bigsize", count=1, value=tlv_type) - _write_field(fd=fd, field_type="bigsize", count=1, value=tlv_len) - _write_field(fd=fd, field_type="byte", count=tlv_len, value=tlv_val) + _write_primitive_field(fd=fd, field_type="bigsize", count=1, value=tlv_type) + _write_primitive_field(fd=fd, field_type="bigsize", count=1, value=tlv_len) + _write_primitive_field(fd=fd, field_type="byte", count=tlv_len, value=tlv_val) def _resolve_field_count(field_count_str: str, *, vars_dict: dict, allow_any=False) -> Union[int, str]: @@ -385,10 +397,19 @@ class LNSerializer: else: pass # TODO - def _write_complex_field(self, *, fd: io.BytesIO, field_type: str, count: Union[int, str], - value: Union[List[Dict[str, Any]], Dict[str, Any]]) -> None: + def write_field( + self, + *, + fd: io.BytesIO, + field_type: str, + count: Union[int, str], + value: Union[List[Dict[str, Any]], Dict[str, Any]] + ) -> None: assert fd - assert field_type in self.subtypes, f"unknown subtype {field_type}" + + if field_type not in self.subtypes: + _write_primitive_field(fd=fd, field_type=field_type, count=count, value=value) + return if isinstance(count, int): assert count >= 0, f"{count!r} must be non-neg int" @@ -428,22 +449,24 @@ class LNSerializer: if subtype_field_name not in record: raise Exception(f'complex field type {field_type} missing element {subtype_field_name}') - if subtype_field_type in self.subtypes: - self._write_complex_field( - fd=fd, - field_type=subtype_field_type, - count=subtype_field_count, - value=record[subtype_field_name]) - else: - _write_field( - fd=fd, - field_type=subtype_field_type, - count=subtype_field_count, - value=record[subtype_field_name]) + self.write_field( + fd=fd, + field_type=subtype_field_type, + count=subtype_field_count, + value=record[subtype_field_name]) - def _read_complex_field(self, *, fd: io.BytesIO, field_type: str, count: Union[int, str])\ - -> Union[bytes, List[Dict[str, Any]], Dict[str, Any]]: + def read_field( + self, + *, + fd: io.BytesIO, + field_type: str, + count: Union[int, str] + ) -> Union[bytes, List[Dict[str, Any]], Dict[str, Any]]: assert fd + + if field_type not in self.subtypes: + return _read_primitive_field(fd=fd, field_type=field_type, count=count) + if isinstance(count, int): assert count >= 0, f"{count!r} must be non-neg int" elif count == "...": @@ -468,16 +491,10 @@ class LNSerializer: vars_dict=parsed, allow_any=True) - if subtype_field_type in self.subtypes: - parsed[subtype_field_name] = self._read_complex_field( - fd=fd, - field_type=subtype_field_type, - count=subtype_field_count) - else: - parsed[subtype_field_name] = _read_field( - fd=fd, - field_type=subtype_field_type, - count=subtype_field_count) + parsed[subtype_field_name] = self.read_field( + fd=fd, + field_type=subtype_field_type, + count=subtype_field_count) parsedlist.append(parsed) return parsedlist if count == '...' or count > 1 else parsedlist[0] @@ -503,18 +520,11 @@ class LNSerializer: vars_dict=kwargs[tlv_record_name], allow_any=True) field_value = kwargs[tlv_record_name][field_name] - if field_type in self.subtypes: - self._write_complex_field( - fd=tlv_record_fd, - field_type=field_type, - count=field_count, - value=field_value) - else: - _write_field( - fd=tlv_record_fd, - field_type=field_type, - count=field_count, - value=field_value) + self.write_field( + fd=tlv_record_fd, + field_type=field_type, + count=field_count, + value=field_value) else: raise Exception(f"unexpected row in scheme: {row!r}") _write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_record_fd.getvalue()) @@ -557,16 +567,10 @@ class LNSerializer: vars_dict=parsed[tlv_record_name], allow_any=True) #print(f">> count={field_count}. parsed={parsed}") - if field_type in self.subtypes: - parsed[tlv_record_name][field_name] = self._read_complex_field( - fd=tlv_record_fd, - field_type=field_type, - count=field_count) - else: - parsed[tlv_record_name][field_name] = _read_field( - fd=tlv_record_fd, - field_type=field_type, - count=field_count) + parsed[tlv_record_name][field_name] = self.read_field( + fd=tlv_record_fd, + field_type=field_type, + count=field_count) else: raise Exception(f"unexpected row in scheme: {row!r}") if _num_remaining_bytes_to_read(tlv_record_fd) > 0: @@ -603,10 +607,7 @@ class LNSerializer: except KeyError: field_value = 0 # default mandatory fields to zero #print(f">>> encode_msg. writing field: {field_name}. value={field_value!r}. field_type={field_type!r}. count={field_count!r}") - _write_field(fd=fd, - field_type=field_type, - count=field_count, - value=field_value) + _write_primitive_field(fd=fd, field_type=field_type, count=field_count, value=field_value) #print(f">>> encode_msg. so far: {fd.getvalue().hex()}") else: raise Exception(f"unexpected row in scheme: {row!r}") @@ -651,10 +652,7 @@ class LNSerializer: parsed[tlv_stream_name] = d continue #print(f">> count={field_count}. parsed={parsed}") - parsed[field_name] = _read_field( - fd=fd, - field_type=field_type, - count=field_count) + parsed[field_name] = _read_primitive_field(fd=fd, field_type=field_type, count=field_count) else: raise Exception(f"unexpected row in scheme: {row!r}") except FailedToParseMsg as e: diff --git a/electrum/onion_message.py b/electrum/onion_message.py index 6e1aa322c..3ad2cc93a 100644 --- a/electrum/onion_message.py +++ b/electrum/onion_message.py @@ -193,7 +193,7 @@ def send_onion_message_to( if len(node_id_or_blinded_path) > 33: # assume blinded path with io.BytesIO(node_id_or_blinded_path) as blinded_path_fd: try: - blinded_path = OnionWireSerializer._read_complex_field( + blinded_path = OnionWireSerializer.read_field( fd=blinded_path_fd, field_type='blinded_path', count=1) diff --git a/tests/test_onion_message.py b/tests/test_onion_message.py index 13eee113a..d855d27db 100644 --- a/tests/test_onion_message.py +++ b/tests/test_onion_message.py @@ -194,7 +194,7 @@ class TestOnionMessage(ElectrumTestCase): # TODO: serialization test to test_lnmsg.py with io.BytesIO() as blinded_path_fd: - OnionWireSerializer._write_complex_field( + OnionWireSerializer.write_field( fd=blinded_path_fd, field_type='blinded_path', count=1,