diff --git a/electrum/lnmsg.py b/electrum/lnmsg.py index 1ff50dc54..ec217dc2d 100644 --- a/electrum/lnmsg.py +++ b/electrum/lnmsg.py @@ -147,6 +147,18 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> U type_len = 33 elif field_type == 'short_channel_id': type_len = 8 + elif field_type == 'sciddir_or_pubkey': + buf = fd.read(1) + if buf[0] in [0, 1]: + type_len = 9 + elif buf[0] in [2, 3]: + type_len = 33 + else: + raise Exception(f"invalid sciddir_or_pubkey, prefix byte not in range 0-3") + buf += fd.read(type_len - 1) + if len(buf) != type_len: + raise UnexpectedEndOfStream() + return buf if count == "...": total_len = -1 # read all @@ -225,6 +237,14 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str], type_len = 33 elif field_type == 'short_channel_id': type_len = 8 + elif field_type == 'sciddir_or_pubkey': + assert isinstance(value, bytes) + if value[0] in [0, 1]: + type_len = 9 # short_channel_id + elif value[0] in [2, 3]: + type_len = 33 # point + else: + raise Exception(f"invalid sciddir_or_pubkey, prefix byte not in range 0-3") total_len = -1 if count != "...": if type_len is None: @@ -299,6 +319,8 @@ class LNSerializer: self.in_tlv_stream_get_record_type_from_name = {} # type: Dict[str, Dict[str, int]] self.in_tlv_stream_get_record_name_from_type = {} # type: Dict[str, Dict[int, str]] + self.subtypes = {} # type: Dict[str, Dict[str, Sequence[str]]] + if for_onion_wire: path = os.path.join(os.path.dirname(__file__), "lnwire", "onion_wire.csv") else: @@ -348,9 +370,112 @@ class LNSerializer: assert tlv_stream_name == row[1] assert tlv_record_name == row[2] self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type].append(tuple(row)) + elif row[0] == "subtype": + # subtype, + subtypename = row[1] + assert subtypename not in self.subtypes, f"duplicate declaration of subtype {subtypename}" + self.subtypes[subtypename] = {} + elif row[0] == "subtypedata": + # subtypedata,,,,[] + subtypename = row[1] + fieldname = row[2] + assert subtypename in self.subtypes, f"subtypedata definition for subtype {subtypename} declared before subtype" + assert fieldname not in self.subtypes[subtypename], f"duplicate field definition for {fieldname} for subtype {subtypename}" + self.subtypes[subtypename][fieldname] = tuple(row) else: pass # TODO + def _write_complex_field(self, *, fd: io.BytesIO, field_type: str, count: Union[int, str], + value: Union[List[Dict[str, Any]], Dict[str, Any]]) -> None: + assert fd + assert field_type in self.subtypes, f"unknown subtype {field_type}" + + if isinstance(count, int): + assert count >= 0, f"{count!r} must be non-neg int" + elif count == "...": + pass + else: + raise Exception(f"unexpected field count: {count!r}") + if count == 0: + return + + if count == 1: + assert isinstance(value, dict) or isinstance(value, list) + values = [value] if isinstance(value, dict) else value + else: + assert isinstance(value, list), f'{field_type=}, expected value of type list for {count=}' + values = value + + if count == '...': + count = len(values) + else: + assert count == len(values), f'{field_type=}, expected {count} but got {len(values)}' + if count == 0: + return + + for record in values: + for subtypename, row in self.subtypes[field_type].items(): + # subtypedata,,,,[] + subtype_field_name = row[2] + subtype_field_type = row[3] + subtype_field_count_str = row[4] + + subtype_field_count = _resolve_field_count(subtype_field_count_str, + vars_dict=record, + allow_any=True) + + if subtype_field_name not in record: + raise Exception(f'complex field type {field_type} missing element {subtype_field_name}') + + if subtype_field_type in self.subtypes: + self._write_complex_field(fd=fd, + field_type=subtype_field_type, + count=subtype_field_count, + value=record[subtype_field_name]) + else: + _write_field(fd=fd, + field_type=subtype_field_type, + count=subtype_field_count, + value=record[subtype_field_name]) + + def _read_complex_field(self, *, fd: io.BytesIO, field_type: str, count: Union[int, str])\ + -> Union[bytes, List[Dict[str, Any]], Dict[str, Any]]: + assert fd + if isinstance(count, int): + assert count >= 0, f"{count!r} must be non-neg int" + elif count == "...": + pass + else: + raise Exception(f"unexpected field count: {count!r}") + if count == 0: + return b"" + + parsedlist = [] + + while _num_remaining_bytes_to_read(fd): + parsed = {} + for subtypename, row in self.subtypes[field_type].items(): + # subtypedata,,,,[] + subtype_field_name = row[2] + subtype_field_type = row[3] + subtype_field_count_str = row[4] + + subtype_field_count = _resolve_field_count(subtype_field_count_str, + vars_dict=parsed, + allow_any=True) + + if subtype_field_type in self.subtypes: + parsed[subtype_field_name] = self._read_complex_field(fd=fd, + field_type=subtype_field_type, + count=subtype_field_count) + else: + parsed[subtype_field_name] = _read_field(fd=fd, + field_type=subtype_field_type, + count=subtype_field_count) + parsedlist.append(parsed) + + return parsedlist if count == '...' or count > 1 else parsedlist[0] + def write_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str, **kwargs) -> None: scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name] for tlv_record_type, scheme in scheme_map.items(): # note: tlv_record_type is monotonically increasing @@ -372,10 +497,16 @@ class LNSerializer: vars_dict=kwargs[tlv_record_name], allow_any=True) field_value = kwargs[tlv_record_name][field_name] - _write_field(fd=tlv_record_fd, - field_type=field_type, - count=field_count, - value=field_value) + if field_type in self.subtypes: + self._write_complex_field(fd=tlv_record_fd, + field_type=field_type, + count=field_count, + value=field_value) + else: + _write_field(fd=tlv_record_fd, + field_type=field_type, + count=field_count, + value=field_value) else: raise Exception(f"unexpected row in scheme: {row!r}") _write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_record_fd.getvalue()) @@ -417,9 +548,14 @@ class LNSerializer: vars_dict=parsed[tlv_record_name], allow_any=True) #print(f">> count={field_count}. parsed={parsed}") - parsed[tlv_record_name][field_name] = _read_field(fd=tlv_record_fd, - field_type=field_type, - count=field_count) + if field_type in self.subtypes: + parsed[tlv_record_name][field_name] = self._read_complex_field(fd=tlv_record_fd, + field_type=field_type, + count=field_count) + else: + parsed[tlv_record_name][field_name] = _read_field(fd=tlv_record_fd, + field_type=field_type, + count=field_count) else: raise Exception(f"unexpected row in scheme: {row!r}") if _num_remaining_bytes_to_read(tlv_record_fd) > 0: