lnmsg: initial TLV implementation
This commit is contained in:
@@ -2,6 +2,7 @@ import os
|
|||||||
import csv
|
import csv
|
||||||
import io
|
import io
|
||||||
from typing import Callable, Tuple, Any, Dict, List, Sequence, Union, Optional
|
from typing import Callable, Tuple, Any, Dict, List, Sequence, Union, Optional
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
|
||||||
class MalformedMsg(Exception):
|
class MalformedMsg(Exception):
|
||||||
@@ -16,12 +17,56 @@ class UnexpectedEndOfStream(MalformedMsg):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None:
|
class FieldEncodingNotMinimal(MalformedMsg):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UnknownMandatoryTLVRecordType(MalformedMsg):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _num_remaining_bytes_to_read(fd: io.BytesIO) -> int:
|
||||||
cur_pos = fd.tell()
|
cur_pos = fd.tell()
|
||||||
end_pos = fd.seek(0, io.SEEK_END)
|
end_pos = fd.seek(0, io.SEEK_END)
|
||||||
fd.seek(cur_pos)
|
fd.seek(cur_pos)
|
||||||
if end_pos - cur_pos < n:
|
return end_pos - cur_pos
|
||||||
raise UnexpectedEndOfStream(f"cur_pos={cur_pos}. end_pos={end_pos}. wants to read: {n}")
|
|
||||||
|
|
||||||
|
def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None:
|
||||||
|
nremaining = _num_remaining_bytes_to_read(fd)
|
||||||
|
if nremaining < n:
|
||||||
|
raise UnexpectedEndOfStream(f"wants to read {n} bytes but only {nremaining} bytes left")
|
||||||
|
|
||||||
|
|
||||||
|
def bigsize_from_int(i: int) -> bytes:
|
||||||
|
assert i >= 0, i
|
||||||
|
if i < 0xfd:
|
||||||
|
return int.to_bytes(i, length=1, byteorder="big", signed=False)
|
||||||
|
elif i < 0x1_0000:
|
||||||
|
return b"\xfd" + int.to_bytes(i, length=2, byteorder="big", signed=False)
|
||||||
|
elif i < 0x1_0000_0000:
|
||||||
|
return b"\xfe" + int.to_bytes(i, length=4, byteorder="big", signed=False)
|
||||||
|
else:
|
||||||
|
return b"\xff" + int.to_bytes(i, length=8, byteorder="big", signed=False)
|
||||||
|
|
||||||
|
|
||||||
|
def read_int_from_bigsize(fd: io.BytesIO) -> Optional[int]:
|
||||||
|
try:
|
||||||
|
first = fd.read(1)[0]
|
||||||
|
except IndexError:
|
||||||
|
return None # end of file
|
||||||
|
if first < 0xfd:
|
||||||
|
return first
|
||||||
|
elif first == 0xfd:
|
||||||
|
_assert_can_read_at_least_n_bytes(fd, 2)
|
||||||
|
return int.from_bytes(fd.read(2), byteorder="big", signed=False)
|
||||||
|
elif first == 0xfe:
|
||||||
|
_assert_can_read_at_least_n_bytes(fd, 4)
|
||||||
|
return int.from_bytes(fd.read(4), byteorder="big", signed=False)
|
||||||
|
elif first == 0xff:
|
||||||
|
_assert_can_read_at_least_n_bytes(fd, 8)
|
||||||
|
return int.from_bytes(fd.read(8), byteorder="big", signed=False)
|
||||||
|
raise Exception()
|
||||||
|
|
||||||
|
|
||||||
def _read_field(*, fd: io.BytesIO, field_type: str, count: int) -> Union[bytes, int]:
|
def _read_field(*, fd: io.BytesIO, field_type: str, count: int) -> Union[bytes, int]:
|
||||||
@@ -32,22 +77,36 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: int) -> Union[bytes,
|
|||||||
type_len = None
|
type_len = None
|
||||||
if field_type == 'byte':
|
if field_type == 'byte':
|
||||||
type_len = 1
|
type_len = 1
|
||||||
elif field_type == 'u16':
|
elif field_type in ('u16', 'u32', 'u64'):
|
||||||
type_len = 2
|
if field_type == 'u16':
|
||||||
|
type_len = 2
|
||||||
|
elif field_type == 'u32':
|
||||||
|
type_len = 4
|
||||||
|
else:
|
||||||
|
assert field_type == 'u64'
|
||||||
|
type_len = 8
|
||||||
assert count == 1, count
|
assert count == 1, count
|
||||||
_assert_can_read_at_least_n_bytes(fd, type_len)
|
_assert_can_read_at_least_n_bytes(fd, type_len)
|
||||||
return int.from_bytes(fd.read(type_len), byteorder="big", signed=False)
|
return int.from_bytes(fd.read(type_len), byteorder="big", signed=False)
|
||||||
elif field_type == 'u32':
|
elif field_type in ('tu16', 'tu32', 'tu64'):
|
||||||
type_len = 4
|
if field_type == 'tu16':
|
||||||
|
type_len = 2
|
||||||
|
elif field_type == 'tu32':
|
||||||
|
type_len = 4
|
||||||
|
else:
|
||||||
|
assert field_type == 'tu64'
|
||||||
|
type_len = 8
|
||||||
assert count == 1, count
|
assert count == 1, count
|
||||||
_assert_can_read_at_least_n_bytes(fd, type_len)
|
raw = fd.read(type_len)
|
||||||
return int.from_bytes(fd.read(type_len), byteorder="big", signed=False)
|
if len(raw) > 0 and raw[0] == 0x00:
|
||||||
elif field_type == 'u64':
|
raise FieldEncodingNotMinimal()
|
||||||
type_len = 8
|
return int.from_bytes(raw, byteorder="big", signed=False)
|
||||||
|
elif field_type == 'varint':
|
||||||
assert count == 1, count
|
assert count == 1, count
|
||||||
_assert_can_read_at_least_n_bytes(fd, type_len)
|
val = read_int_from_bigsize(fd)
|
||||||
return int.from_bytes(fd.read(type_len), byteorder="big", signed=False)
|
if val is None:
|
||||||
# TODO tu16/tu32/tu64
|
raise UnexpectedEndOfStream()
|
||||||
|
return val
|
||||||
elif field_type == 'chain_hash':
|
elif field_type == 'chain_hash':
|
||||||
type_len = 32
|
type_len = 32
|
||||||
elif field_type == 'channel_id':
|
elif field_type == 'channel_id':
|
||||||
@@ -82,7 +141,35 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: int,
|
|||||||
type_len = 4
|
type_len = 4
|
||||||
elif field_type == 'u64':
|
elif field_type == 'u64':
|
||||||
type_len = 8
|
type_len = 8
|
||||||
# TODO tu16/tu32/tu64
|
elif field_type in ('tu16', 'tu32', 'tu64'):
|
||||||
|
if field_type == 'tu16':
|
||||||
|
type_len = 2
|
||||||
|
elif field_type == 'tu32':
|
||||||
|
type_len = 4
|
||||||
|
else:
|
||||||
|
assert field_type == 'tu64'
|
||||||
|
type_len = 8
|
||||||
|
assert count == 1, count
|
||||||
|
if isinstance(value, int):
|
||||||
|
value = int.to_bytes(value, length=type_len, byteorder="big", signed=False)
|
||||||
|
if not isinstance(value, (bytes, bytearray)):
|
||||||
|
raise Exception(f"can only write bytes into fd. got: {value!r}")
|
||||||
|
while len(value) > 0 and value[0] == 0x00:
|
||||||
|
value = value[1:]
|
||||||
|
nbytes_written = fd.write(value)
|
||||||
|
if nbytes_written != len(value):
|
||||||
|
raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
|
||||||
|
return
|
||||||
|
elif field_type == 'varint':
|
||||||
|
assert count == 1, count
|
||||||
|
if isinstance(value, int):
|
||||||
|
value = bigsize_from_int(value)
|
||||||
|
if not isinstance(value, (bytes, bytearray)):
|
||||||
|
raise Exception(f"can only write bytes into fd. got: {value!r}")
|
||||||
|
nbytes_written = fd.write(value)
|
||||||
|
if nbytes_written != len(value):
|
||||||
|
raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
|
||||||
|
return
|
||||||
elif field_type == 'chain_hash':
|
elif field_type == 'chain_hash':
|
||||||
type_len = 32
|
type_len = 32
|
||||||
elif field_type == 'channel_id':
|
elif field_type == 'channel_id':
|
||||||
@@ -109,16 +196,55 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: int,
|
|||||||
raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
|
raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
|
||||||
|
|
||||||
|
|
||||||
|
def _read_tlv_record(*, fd: io.BytesIO) -> Tuple[int, bytes]:
|
||||||
|
if not fd: raise Exception()
|
||||||
|
tlv_type = _read_field(fd=fd, field_type="varint", count=1)
|
||||||
|
tlv_len = _read_field(fd=fd, field_type="varint", count=1)
|
||||||
|
tlv_val = _read_field(fd=fd, field_type="byte", count=tlv_len)
|
||||||
|
return tlv_type, tlv_val
|
||||||
|
|
||||||
|
|
||||||
|
def _write_tlv_record(*, fd: io.BytesIO, tlv_type: int, tlv_val: bytes) -> None:
|
||||||
|
if not fd: raise Exception()
|
||||||
|
tlv_len = len(tlv_val)
|
||||||
|
_write_field(fd=fd, field_type="varint", count=1, value=tlv_type)
|
||||||
|
_write_field(fd=fd, field_type="varint", count=1, value=tlv_len)
|
||||||
|
_write_field(fd=fd, field_type="byte", count=tlv_len, value=tlv_val)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_field_count(field_count_str: str, *, vars_dict: dict) -> int:
|
||||||
|
if field_count_str == "":
|
||||||
|
field_count = 1
|
||||||
|
elif field_count_str == "...":
|
||||||
|
raise NotImplementedError() # TODO...
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
field_count = int(field_count_str)
|
||||||
|
except ValueError:
|
||||||
|
field_count = vars_dict[field_count_str]
|
||||||
|
if isinstance(field_count, (bytes, bytearray)):
|
||||||
|
field_count = int.from_bytes(field_count, byteorder="big")
|
||||||
|
assert isinstance(field_count, int)
|
||||||
|
return field_count
|
||||||
|
|
||||||
|
|
||||||
class LNSerializer:
|
class LNSerializer:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
# TODO msg_type could be 'int' everywhere...
|
||||||
self.msg_scheme_from_type = {} # type: Dict[bytes, List[Sequence[str]]]
|
self.msg_scheme_from_type = {} # type: Dict[bytes, List[Sequence[str]]]
|
||||||
self.msg_type_from_name = {} # type: Dict[str, bytes]
|
self.msg_type_from_name = {} # type: Dict[str, bytes]
|
||||||
|
|
||||||
|
self.in_tlv_stream_get_tlv_record_scheme_from_type = {} # type: Dict[str, Dict[int, List[Sequence[str]]]]
|
||||||
|
self.in_tlv_stream_get_record_type_from_name = {} # type: Dict[str, Dict[str, int]]
|
||||||
|
self.in_tlv_stream_get_record_name_from_type = {} # type: Dict[str, Dict[int, str]]
|
||||||
|
|
||||||
path = os.path.join(os.path.dirname(__file__), "lnwire", "peer_wire.csv")
|
path = os.path.join(os.path.dirname(__file__), "lnwire", "peer_wire.csv")
|
||||||
with open(path, newline='') as f:
|
with open(path, newline='') as f:
|
||||||
csvreader = csv.reader(f)
|
csvreader = csv.reader(f)
|
||||||
for row in csvreader:
|
for row in csvreader:
|
||||||
#print(f">>> {row!r}")
|
#print(f">>> {row!r}")
|
||||||
if row[0] == "msgtype":
|
if row[0] == "msgtype":
|
||||||
|
# msgtype,<msgname>,<value>[,<option>]
|
||||||
msg_type_name = row[1]
|
msg_type_name = row[1]
|
||||||
msg_type_int = int(row[2])
|
msg_type_int = int(row[2])
|
||||||
msg_type_bytes = msg_type_int.to_bytes(2, 'big')
|
msg_type_bytes = msg_type_int.to_bytes(2, 'big')
|
||||||
@@ -128,11 +254,106 @@ class LNSerializer:
|
|||||||
self.msg_scheme_from_type[msg_type_bytes] = [tuple(row)]
|
self.msg_scheme_from_type[msg_type_bytes] = [tuple(row)]
|
||||||
self.msg_type_from_name[msg_type_name] = msg_type_bytes
|
self.msg_type_from_name[msg_type_name] = msg_type_bytes
|
||||||
elif row[0] == "msgdata":
|
elif row[0] == "msgdata":
|
||||||
|
# msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
|
||||||
assert msg_type_name == row[1]
|
assert msg_type_name == row[1]
|
||||||
self.msg_scheme_from_type[msg_type_bytes].append(tuple(row))
|
self.msg_scheme_from_type[msg_type_bytes].append(tuple(row))
|
||||||
|
elif row[0] == "tlvtype":
|
||||||
|
# tlvtype,<tlvstreamname>,<tlvname>,<value>[,<option>]
|
||||||
|
tlv_stream_name = row[1]
|
||||||
|
tlv_record_name = row[2]
|
||||||
|
tlv_record_type = int(row[3])
|
||||||
|
row[3] = tlv_record_type
|
||||||
|
if tlv_stream_name not in self.in_tlv_stream_get_tlv_record_scheme_from_type:
|
||||||
|
self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name] = OrderedDict()
|
||||||
|
self.in_tlv_stream_get_record_type_from_name[tlv_stream_name] = {}
|
||||||
|
self.in_tlv_stream_get_record_name_from_type[tlv_stream_name] = {}
|
||||||
|
assert tlv_record_type not in self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
|
||||||
|
assert tlv_record_name not in self.in_tlv_stream_get_record_type_from_name[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
|
||||||
|
assert tlv_record_type not in self.in_tlv_stream_get_record_type_from_name[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
|
||||||
|
self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type] = [tuple(row)]
|
||||||
|
self.in_tlv_stream_get_record_type_from_name[tlv_stream_name][tlv_record_name] = tlv_record_type
|
||||||
|
self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type] = tlv_record_name
|
||||||
|
if max(self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name].keys()) > tlv_record_type:
|
||||||
|
raise Exception(f"tlv record types must be listed in monotonically increasing order for stream. "
|
||||||
|
f"stream={tlv_stream_name}")
|
||||||
|
elif row[0] == "tlvdata":
|
||||||
|
# tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
|
||||||
|
assert tlv_stream_name == row[1]
|
||||||
|
assert tlv_record_name == row[2]
|
||||||
|
self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type].append(tuple(row))
|
||||||
else:
|
else:
|
||||||
pass # TODO
|
pass # TODO
|
||||||
|
|
||||||
|
def write_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str, **kwargs) -> None:
|
||||||
|
scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
|
||||||
|
for tlv_record_type, scheme in scheme_map.items(): # note: tlv_record_type is monotonically increasing
|
||||||
|
tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type]
|
||||||
|
if tlv_record_name not in kwargs:
|
||||||
|
continue
|
||||||
|
with io.BytesIO() as tlv_record_fd:
|
||||||
|
for row in scheme:
|
||||||
|
if row[0] == "tlvtype":
|
||||||
|
pass
|
||||||
|
elif row[0] == "tlvdata":
|
||||||
|
# tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
|
||||||
|
assert tlv_stream_name == row[1]
|
||||||
|
assert tlv_record_name == row[2]
|
||||||
|
field_name = row[3]
|
||||||
|
field_type = row[4]
|
||||||
|
field_count_str = row[5]
|
||||||
|
field_count = _resolve_field_count(field_count_str, vars_dict=kwargs[tlv_record_name])
|
||||||
|
field_value = kwargs[tlv_record_name][field_name]
|
||||||
|
_write_field(fd=tlv_record_fd,
|
||||||
|
field_type=field_type,
|
||||||
|
count=field_count,
|
||||||
|
value=field_value)
|
||||||
|
else:
|
||||||
|
pass # TODO
|
||||||
|
_write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_record_fd.getvalue())
|
||||||
|
|
||||||
|
def read_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str) -> Dict[str, Dict[str, Any]]:
|
||||||
|
parsed = {} # type: Dict[str, Dict[str, Any]]
|
||||||
|
scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
|
||||||
|
last_seen_tlv_record_type = -1 # type: int
|
||||||
|
while _num_remaining_bytes_to_read(fd) > 0:
|
||||||
|
tlv_record_type, tlv_record_val = _read_tlv_record(fd=fd)
|
||||||
|
if not (tlv_record_type > last_seen_tlv_record_type):
|
||||||
|
raise MalformedMsg("TLV records must be monotonically increasing by type")
|
||||||
|
last_seen_tlv_record_type = tlv_record_type
|
||||||
|
try:
|
||||||
|
scheme = scheme_map[tlv_record_type]
|
||||||
|
except KeyError:
|
||||||
|
if tlv_record_type % 2 == 0:
|
||||||
|
# unknown "even" type: hard fail
|
||||||
|
raise UnknownMandatoryTLVRecordType(f"{tlv_stream_name}/{tlv_record_type}") from None
|
||||||
|
else:
|
||||||
|
# unknown "odd" type: skip it
|
||||||
|
continue
|
||||||
|
tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type]
|
||||||
|
parsed[tlv_record_name] = {}
|
||||||
|
with io.BytesIO(tlv_record_val) as tlv_record_fd:
|
||||||
|
for row in scheme:
|
||||||
|
#print(f"row: {row!r}")
|
||||||
|
if row[0] == "tlvtype":
|
||||||
|
pass
|
||||||
|
elif row[0] == "tlvdata":
|
||||||
|
# tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
|
||||||
|
assert tlv_stream_name == row[1]
|
||||||
|
assert tlv_record_name == row[2]
|
||||||
|
field_name = row[3]
|
||||||
|
field_type = row[4]
|
||||||
|
field_count_str = row[5]
|
||||||
|
field_count = _resolve_field_count(field_count_str, vars_dict=parsed[tlv_record_name])
|
||||||
|
#print(f">> count={field_count}. parsed={parsed}")
|
||||||
|
parsed[tlv_record_name][field_name] = _read_field(fd=tlv_record_fd,
|
||||||
|
field_type=field_type,
|
||||||
|
count=field_count)
|
||||||
|
else:
|
||||||
|
pass # TODO
|
||||||
|
if _num_remaining_bytes_to_read(tlv_record_fd) > 0:
|
||||||
|
raise MalformedMsg(f"TLV record ({tlv_stream_name}/{tlv_record_name}) has extra trailing garbage")
|
||||||
|
return parsed
|
||||||
|
|
||||||
def encode_msg(self, msg_type: str, **kwargs) -> bytes:
|
def encode_msg(self, msg_type: str, **kwargs) -> bytes:
|
||||||
"""
|
"""
|
||||||
Encode kwargs into a Lightning message (bytes)
|
Encode kwargs into a Lightning message (bytes)
|
||||||
@@ -147,20 +368,12 @@ class LNSerializer:
|
|||||||
if row[0] == "msgtype":
|
if row[0] == "msgtype":
|
||||||
pass
|
pass
|
||||||
elif row[0] == "msgdata":
|
elif row[0] == "msgdata":
|
||||||
|
# msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
|
||||||
field_name = row[2]
|
field_name = row[2]
|
||||||
field_type = row[3]
|
field_type = row[3]
|
||||||
field_count_str = row[4]
|
field_count_str = row[4]
|
||||||
#print(f">>> encode_msg. msgdata. field_name={field_name!r}. field_type={field_type!r}. field_count_str={field_count_str!r}")
|
#print(f">>> encode_msg. msgdata. field_name={field_name!r}. field_type={field_type!r}. field_count_str={field_count_str!r}")
|
||||||
if field_count_str == "":
|
field_count = _resolve_field_count(field_count_str, vars_dict=kwargs)
|
||||||
field_count = 1
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
field_count = int(field_count_str)
|
|
||||||
except ValueError:
|
|
||||||
field_count = kwargs[field_count_str]
|
|
||||||
if isinstance(field_count, (bytes, bytearray)):
|
|
||||||
field_count = int.from_bytes(field_count, byteorder="big")
|
|
||||||
assert isinstance(field_count, int)
|
|
||||||
try:
|
try:
|
||||||
field_value = kwargs[field_name]
|
field_value = kwargs[field_name]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@@ -205,14 +418,7 @@ class LNSerializer:
|
|||||||
field_name = row[2]
|
field_name = row[2]
|
||||||
field_type = row[3]
|
field_type = row[3]
|
||||||
field_count_str = row[4]
|
field_count_str = row[4]
|
||||||
if field_count_str == "":
|
field_count = _resolve_field_count(field_count_str, vars_dict=parsed)
|
||||||
field_count = 1
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
field_count = int(field_count_str)
|
|
||||||
except ValueError:
|
|
||||||
field_count = parsed[field_count_str]
|
|
||||||
assert isinstance(field_count, int)
|
|
||||||
#print(f">> count={field_count}. parsed={parsed}")
|
#print(f">> count={field_count}. parsed={parsed}")
|
||||||
try:
|
try:
|
||||||
parsed[field_name] = _read_field(fd=fd,
|
parsed[field_name] = _read_field(fd=fd,
|
||||||
|
|||||||
Reference in New Issue
Block a user