lnmsg: initial TLV implementation
This commit is contained in:
@@ -2,6 +2,7 @@ import os
|
||||
import csv
|
||||
import io
|
||||
from typing import Callable, Tuple, Any, Dict, List, Sequence, Union, Optional
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class MalformedMsg(Exception):
|
||||
@@ -16,12 +17,56 @@ class UnexpectedEndOfStream(MalformedMsg):
|
||||
pass
|
||||
|
||||
|
||||
def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None:
|
||||
class FieldEncodingNotMinimal(MalformedMsg):
|
||||
pass
|
||||
|
||||
|
||||
class UnknownMandatoryTLVRecordType(MalformedMsg):
|
||||
pass
|
||||
|
||||
|
||||
def _num_remaining_bytes_to_read(fd: io.BytesIO) -> int:
|
||||
cur_pos = fd.tell()
|
||||
end_pos = fd.seek(0, io.SEEK_END)
|
||||
fd.seek(cur_pos)
|
||||
if end_pos - cur_pos < n:
|
||||
raise UnexpectedEndOfStream(f"cur_pos={cur_pos}. end_pos={end_pos}. wants to read: {n}")
|
||||
return end_pos - cur_pos
|
||||
|
||||
|
||||
def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None:
|
||||
nremaining = _num_remaining_bytes_to_read(fd)
|
||||
if nremaining < n:
|
||||
raise UnexpectedEndOfStream(f"wants to read {n} bytes but only {nremaining} bytes left")
|
||||
|
||||
|
||||
def bigsize_from_int(i: int) -> bytes:
|
||||
assert i >= 0, i
|
||||
if i < 0xfd:
|
||||
return int.to_bytes(i, length=1, byteorder="big", signed=False)
|
||||
elif i < 0x1_0000:
|
||||
return b"\xfd" + int.to_bytes(i, length=2, byteorder="big", signed=False)
|
||||
elif i < 0x1_0000_0000:
|
||||
return b"\xfe" + int.to_bytes(i, length=4, byteorder="big", signed=False)
|
||||
else:
|
||||
return b"\xff" + int.to_bytes(i, length=8, byteorder="big", signed=False)
|
||||
|
||||
|
||||
def read_int_from_bigsize(fd: io.BytesIO) -> Optional[int]:
|
||||
try:
|
||||
first = fd.read(1)[0]
|
||||
except IndexError:
|
||||
return None # end of file
|
||||
if first < 0xfd:
|
||||
return first
|
||||
elif first == 0xfd:
|
||||
_assert_can_read_at_least_n_bytes(fd, 2)
|
||||
return int.from_bytes(fd.read(2), byteorder="big", signed=False)
|
||||
elif first == 0xfe:
|
||||
_assert_can_read_at_least_n_bytes(fd, 4)
|
||||
return int.from_bytes(fd.read(4), byteorder="big", signed=False)
|
||||
elif first == 0xff:
|
||||
_assert_can_read_at_least_n_bytes(fd, 8)
|
||||
return int.from_bytes(fd.read(8), byteorder="big", signed=False)
|
||||
raise Exception()
|
||||
|
||||
|
||||
def _read_field(*, fd: io.BytesIO, field_type: str, count: int) -> Union[bytes, int]:
|
||||
@@ -32,22 +77,36 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: int) -> Union[bytes,
|
||||
type_len = None
|
||||
if field_type == 'byte':
|
||||
type_len = 1
|
||||
elif field_type == 'u16':
|
||||
type_len = 2
|
||||
elif field_type in ('u16', 'u32', 'u64'):
|
||||
if field_type == 'u16':
|
||||
type_len = 2
|
||||
elif field_type == 'u32':
|
||||
type_len = 4
|
||||
else:
|
||||
assert field_type == 'u64'
|
||||
type_len = 8
|
||||
assert count == 1, count
|
||||
_assert_can_read_at_least_n_bytes(fd, type_len)
|
||||
return int.from_bytes(fd.read(type_len), byteorder="big", signed=False)
|
||||
elif field_type == 'u32':
|
||||
type_len = 4
|
||||
elif field_type in ('tu16', 'tu32', 'tu64'):
|
||||
if field_type == 'tu16':
|
||||
type_len = 2
|
||||
elif field_type == 'tu32':
|
||||
type_len = 4
|
||||
else:
|
||||
assert field_type == 'tu64'
|
||||
type_len = 8
|
||||
assert count == 1, count
|
||||
_assert_can_read_at_least_n_bytes(fd, type_len)
|
||||
return int.from_bytes(fd.read(type_len), byteorder="big", signed=False)
|
||||
elif field_type == 'u64':
|
||||
type_len = 8
|
||||
raw = fd.read(type_len)
|
||||
if len(raw) > 0 and raw[0] == 0x00:
|
||||
raise FieldEncodingNotMinimal()
|
||||
return int.from_bytes(raw, byteorder="big", signed=False)
|
||||
elif field_type == 'varint':
|
||||
assert count == 1, count
|
||||
_assert_can_read_at_least_n_bytes(fd, type_len)
|
||||
return int.from_bytes(fd.read(type_len), byteorder="big", signed=False)
|
||||
# TODO tu16/tu32/tu64
|
||||
val = read_int_from_bigsize(fd)
|
||||
if val is None:
|
||||
raise UnexpectedEndOfStream()
|
||||
return val
|
||||
elif field_type == 'chain_hash':
|
||||
type_len = 32
|
||||
elif field_type == 'channel_id':
|
||||
@@ -82,7 +141,35 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: int,
|
||||
type_len = 4
|
||||
elif field_type == 'u64':
|
||||
type_len = 8
|
||||
# TODO tu16/tu32/tu64
|
||||
elif field_type in ('tu16', 'tu32', 'tu64'):
|
||||
if field_type == 'tu16':
|
||||
type_len = 2
|
||||
elif field_type == 'tu32':
|
||||
type_len = 4
|
||||
else:
|
||||
assert field_type == 'tu64'
|
||||
type_len = 8
|
||||
assert count == 1, count
|
||||
if isinstance(value, int):
|
||||
value = int.to_bytes(value, length=type_len, byteorder="big", signed=False)
|
||||
if not isinstance(value, (bytes, bytearray)):
|
||||
raise Exception(f"can only write bytes into fd. got: {value!r}")
|
||||
while len(value) > 0 and value[0] == 0x00:
|
||||
value = value[1:]
|
||||
nbytes_written = fd.write(value)
|
||||
if nbytes_written != len(value):
|
||||
raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
|
||||
return
|
||||
elif field_type == 'varint':
|
||||
assert count == 1, count
|
||||
if isinstance(value, int):
|
||||
value = bigsize_from_int(value)
|
||||
if not isinstance(value, (bytes, bytearray)):
|
||||
raise Exception(f"can only write bytes into fd. got: {value!r}")
|
||||
nbytes_written = fd.write(value)
|
||||
if nbytes_written != len(value):
|
||||
raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
|
||||
return
|
||||
elif field_type == 'chain_hash':
|
||||
type_len = 32
|
||||
elif field_type == 'channel_id':
|
||||
@@ -109,16 +196,55 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: int,
|
||||
raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
|
||||
|
||||
|
||||
def _read_tlv_record(*, fd: io.BytesIO) -> Tuple[int, bytes]:
|
||||
if not fd: raise Exception()
|
||||
tlv_type = _read_field(fd=fd, field_type="varint", count=1)
|
||||
tlv_len = _read_field(fd=fd, field_type="varint", count=1)
|
||||
tlv_val = _read_field(fd=fd, field_type="byte", count=tlv_len)
|
||||
return tlv_type, tlv_val
|
||||
|
||||
|
||||
def _write_tlv_record(*, fd: io.BytesIO, tlv_type: int, tlv_val: bytes) -> None:
|
||||
if not fd: raise Exception()
|
||||
tlv_len = len(tlv_val)
|
||||
_write_field(fd=fd, field_type="varint", count=1, value=tlv_type)
|
||||
_write_field(fd=fd, field_type="varint", count=1, value=tlv_len)
|
||||
_write_field(fd=fd, field_type="byte", count=tlv_len, value=tlv_val)
|
||||
|
||||
|
||||
def _resolve_field_count(field_count_str: str, *, vars_dict: dict) -> int:
|
||||
if field_count_str == "":
|
||||
field_count = 1
|
||||
elif field_count_str == "...":
|
||||
raise NotImplementedError() # TODO...
|
||||
else:
|
||||
try:
|
||||
field_count = int(field_count_str)
|
||||
except ValueError:
|
||||
field_count = vars_dict[field_count_str]
|
||||
if isinstance(field_count, (bytes, bytearray)):
|
||||
field_count = int.from_bytes(field_count, byteorder="big")
|
||||
assert isinstance(field_count, int)
|
||||
return field_count
|
||||
|
||||
|
||||
class LNSerializer:
|
||||
def __init__(self):
|
||||
# TODO msg_type could be 'int' everywhere...
|
||||
self.msg_scheme_from_type = {} # type: Dict[bytes, List[Sequence[str]]]
|
||||
self.msg_type_from_name = {} # type: Dict[str, bytes]
|
||||
|
||||
self.in_tlv_stream_get_tlv_record_scheme_from_type = {} # type: Dict[str, Dict[int, List[Sequence[str]]]]
|
||||
self.in_tlv_stream_get_record_type_from_name = {} # type: Dict[str, Dict[str, int]]
|
||||
self.in_tlv_stream_get_record_name_from_type = {} # type: Dict[str, Dict[int, str]]
|
||||
|
||||
path = os.path.join(os.path.dirname(__file__), "lnwire", "peer_wire.csv")
|
||||
with open(path, newline='') as f:
|
||||
csvreader = csv.reader(f)
|
||||
for row in csvreader:
|
||||
#print(f">>> {row!r}")
|
||||
if row[0] == "msgtype":
|
||||
# msgtype,<msgname>,<value>[,<option>]
|
||||
msg_type_name = row[1]
|
||||
msg_type_int = int(row[2])
|
||||
msg_type_bytes = msg_type_int.to_bytes(2, 'big')
|
||||
@@ -128,11 +254,106 @@ class LNSerializer:
|
||||
self.msg_scheme_from_type[msg_type_bytes] = [tuple(row)]
|
||||
self.msg_type_from_name[msg_type_name] = msg_type_bytes
|
||||
elif row[0] == "msgdata":
|
||||
# msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
|
||||
assert msg_type_name == row[1]
|
||||
self.msg_scheme_from_type[msg_type_bytes].append(tuple(row))
|
||||
elif row[0] == "tlvtype":
|
||||
# tlvtype,<tlvstreamname>,<tlvname>,<value>[,<option>]
|
||||
tlv_stream_name = row[1]
|
||||
tlv_record_name = row[2]
|
||||
tlv_record_type = int(row[3])
|
||||
row[3] = tlv_record_type
|
||||
if tlv_stream_name not in self.in_tlv_stream_get_tlv_record_scheme_from_type:
|
||||
self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name] = OrderedDict()
|
||||
self.in_tlv_stream_get_record_type_from_name[tlv_stream_name] = {}
|
||||
self.in_tlv_stream_get_record_name_from_type[tlv_stream_name] = {}
|
||||
assert tlv_record_type not in self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
|
||||
assert tlv_record_name not in self.in_tlv_stream_get_record_type_from_name[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
|
||||
assert tlv_record_type not in self.in_tlv_stream_get_record_type_from_name[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
|
||||
self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type] = [tuple(row)]
|
||||
self.in_tlv_stream_get_record_type_from_name[tlv_stream_name][tlv_record_name] = tlv_record_type
|
||||
self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type] = tlv_record_name
|
||||
if max(self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name].keys()) > tlv_record_type:
|
||||
raise Exception(f"tlv record types must be listed in monotonically increasing order for stream. "
|
||||
f"stream={tlv_stream_name}")
|
||||
elif row[0] == "tlvdata":
|
||||
# tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
|
||||
assert tlv_stream_name == row[1]
|
||||
assert tlv_record_name == row[2]
|
||||
self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type].append(tuple(row))
|
||||
else:
|
||||
pass # TODO
|
||||
|
||||
def write_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str, **kwargs) -> None:
|
||||
scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
|
||||
for tlv_record_type, scheme in scheme_map.items(): # note: tlv_record_type is monotonically increasing
|
||||
tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type]
|
||||
if tlv_record_name not in kwargs:
|
||||
continue
|
||||
with io.BytesIO() as tlv_record_fd:
|
||||
for row in scheme:
|
||||
if row[0] == "tlvtype":
|
||||
pass
|
||||
elif row[0] == "tlvdata":
|
||||
# tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
|
||||
assert tlv_stream_name == row[1]
|
||||
assert tlv_record_name == row[2]
|
||||
field_name = row[3]
|
||||
field_type = row[4]
|
||||
field_count_str = row[5]
|
||||
field_count = _resolve_field_count(field_count_str, vars_dict=kwargs[tlv_record_name])
|
||||
field_value = kwargs[tlv_record_name][field_name]
|
||||
_write_field(fd=tlv_record_fd,
|
||||
field_type=field_type,
|
||||
count=field_count,
|
||||
value=field_value)
|
||||
else:
|
||||
pass # TODO
|
||||
_write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_record_fd.getvalue())
|
||||
|
||||
def read_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str) -> Dict[str, Dict[str, Any]]:
|
||||
parsed = {} # type: Dict[str, Dict[str, Any]]
|
||||
scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
|
||||
last_seen_tlv_record_type = -1 # type: int
|
||||
while _num_remaining_bytes_to_read(fd) > 0:
|
||||
tlv_record_type, tlv_record_val = _read_tlv_record(fd=fd)
|
||||
if not (tlv_record_type > last_seen_tlv_record_type):
|
||||
raise MalformedMsg("TLV records must be monotonically increasing by type")
|
||||
last_seen_tlv_record_type = tlv_record_type
|
||||
try:
|
||||
scheme = scheme_map[tlv_record_type]
|
||||
except KeyError:
|
||||
if tlv_record_type % 2 == 0:
|
||||
# unknown "even" type: hard fail
|
||||
raise UnknownMandatoryTLVRecordType(f"{tlv_stream_name}/{tlv_record_type}") from None
|
||||
else:
|
||||
# unknown "odd" type: skip it
|
||||
continue
|
||||
tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type]
|
||||
parsed[tlv_record_name] = {}
|
||||
with io.BytesIO(tlv_record_val) as tlv_record_fd:
|
||||
for row in scheme:
|
||||
#print(f"row: {row!r}")
|
||||
if row[0] == "tlvtype":
|
||||
pass
|
||||
elif row[0] == "tlvdata":
|
||||
# tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
|
||||
assert tlv_stream_name == row[1]
|
||||
assert tlv_record_name == row[2]
|
||||
field_name = row[3]
|
||||
field_type = row[4]
|
||||
field_count_str = row[5]
|
||||
field_count = _resolve_field_count(field_count_str, vars_dict=parsed[tlv_record_name])
|
||||
#print(f">> count={field_count}. parsed={parsed}")
|
||||
parsed[tlv_record_name][field_name] = _read_field(fd=tlv_record_fd,
|
||||
field_type=field_type,
|
||||
count=field_count)
|
||||
else:
|
||||
pass # TODO
|
||||
if _num_remaining_bytes_to_read(tlv_record_fd) > 0:
|
||||
raise MalformedMsg(f"TLV record ({tlv_stream_name}/{tlv_record_name}) has extra trailing garbage")
|
||||
return parsed
|
||||
|
||||
def encode_msg(self, msg_type: str, **kwargs) -> bytes:
|
||||
"""
|
||||
Encode kwargs into a Lightning message (bytes)
|
||||
@@ -147,20 +368,12 @@ class LNSerializer:
|
||||
if row[0] == "msgtype":
|
||||
pass
|
||||
elif row[0] == "msgdata":
|
||||
# msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
|
||||
field_name = row[2]
|
||||
field_type = row[3]
|
||||
field_count_str = row[4]
|
||||
#print(f">>> encode_msg. msgdata. field_name={field_name!r}. field_type={field_type!r}. field_count_str={field_count_str!r}")
|
||||
if field_count_str == "":
|
||||
field_count = 1
|
||||
else:
|
||||
try:
|
||||
field_count = int(field_count_str)
|
||||
except ValueError:
|
||||
field_count = kwargs[field_count_str]
|
||||
if isinstance(field_count, (bytes, bytearray)):
|
||||
field_count = int.from_bytes(field_count, byteorder="big")
|
||||
assert isinstance(field_count, int)
|
||||
field_count = _resolve_field_count(field_count_str, vars_dict=kwargs)
|
||||
try:
|
||||
field_value = kwargs[field_name]
|
||||
except KeyError:
|
||||
@@ -205,14 +418,7 @@ class LNSerializer:
|
||||
field_name = row[2]
|
||||
field_type = row[3]
|
||||
field_count_str = row[4]
|
||||
if field_count_str == "":
|
||||
field_count = 1
|
||||
else:
|
||||
try:
|
||||
field_count = int(field_count_str)
|
||||
except ValueError:
|
||||
field_count = parsed[field_count_str]
|
||||
assert isinstance(field_count, int)
|
||||
field_count = _resolve_field_count(field_count_str, vars_dict=parsed)
|
||||
#print(f">> count={field_count}. parsed={parsed}")
|
||||
try:
|
||||
parsed[field_name] = _read_field(fd=fd,
|
||||
|
||||
Reference in New Issue
Block a user