lnmsg: implement tests from BOLT-01
This commit is contained in:
@@ -5,24 +5,13 @@ from typing import Callable, Tuple, Any, Dict, List, Sequence, Union, Optional
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class MalformedMsg(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UnknownMsgFieldType(MalformedMsg):
|
||||
pass
|
||||
|
||||
|
||||
class UnexpectedEndOfStream(MalformedMsg):
|
||||
pass
|
||||
|
||||
|
||||
class FieldEncodingNotMinimal(MalformedMsg):
|
||||
pass
|
||||
|
||||
|
||||
class UnknownMandatoryTLVRecordType(MalformedMsg):
|
||||
pass
|
||||
class MalformedMsg(Exception): pass
|
||||
class UnknownMsgFieldType(MalformedMsg): pass
|
||||
class UnexpectedEndOfStream(MalformedMsg): pass
|
||||
class FieldEncodingNotMinimal(MalformedMsg): pass
|
||||
class UnknownMandatoryTLVRecordType(MalformedMsg): pass
|
||||
class MsgTrailingGarbage(MalformedMsg): pass
|
||||
class MsgInvalidFieldOrder(MalformedMsg): pass
|
||||
|
||||
|
||||
def _num_remaining_bytes_to_read(fd: io.BytesIO) -> int:
|
||||
@@ -38,7 +27,7 @@ def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None:
|
||||
raise UnexpectedEndOfStream(f"wants to read {n} bytes but only {nremaining} bytes left")
|
||||
|
||||
|
||||
def bigsize_from_int(i: int) -> bytes:
|
||||
def write_bigsize_int(i: int) -> bytes:
|
||||
assert i >= 0, i
|
||||
if i < 0xfd:
|
||||
return int.to_bytes(i, length=1, byteorder="big", signed=False)
|
||||
@@ -50,7 +39,7 @@ def bigsize_from_int(i: int) -> bytes:
|
||||
return b"\xff" + int.to_bytes(i, length=8, byteorder="big", signed=False)
|
||||
|
||||
|
||||
def read_int_from_bigsize(fd: io.BytesIO) -> Optional[int]:
|
||||
def read_bigsize_int(fd: io.BytesIO) -> Optional[int]:
|
||||
try:
|
||||
first = fd.read(1)[0]
|
||||
except IndexError:
|
||||
@@ -59,13 +48,22 @@ def read_int_from_bigsize(fd: io.BytesIO) -> Optional[int]:
|
||||
return first
|
||||
elif first == 0xfd:
|
||||
_assert_can_read_at_least_n_bytes(fd, 2)
|
||||
return int.from_bytes(fd.read(2), byteorder="big", signed=False)
|
||||
val = int.from_bytes(fd.read(2), byteorder="big", signed=False)
|
||||
if not (0xfd <= val < 0x1_0000):
|
||||
raise FieldEncodingNotMinimal()
|
||||
return val
|
||||
elif first == 0xfe:
|
||||
_assert_can_read_at_least_n_bytes(fd, 4)
|
||||
return int.from_bytes(fd.read(4), byteorder="big", signed=False)
|
||||
val = int.from_bytes(fd.read(4), byteorder="big", signed=False)
|
||||
if not (0x1_0000 <= val < 0x1_0000_0000):
|
||||
raise FieldEncodingNotMinimal()
|
||||
return val
|
||||
elif first == 0xff:
|
||||
_assert_can_read_at_least_n_bytes(fd, 8)
|
||||
return int.from_bytes(fd.read(8), byteorder="big", signed=False)
|
||||
val = int.from_bytes(fd.read(8), byteorder="big", signed=False)
|
||||
if not (0x1_0000_0000 <= val):
|
||||
raise FieldEncodingNotMinimal()
|
||||
return val
|
||||
raise Exception()
|
||||
|
||||
|
||||
@@ -112,7 +110,7 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> U
|
||||
return int.from_bytes(raw, byteorder="big", signed=False)
|
||||
elif field_type == 'varint':
|
||||
assert count == 1, count
|
||||
val = read_int_from_bigsize(fd)
|
||||
val = read_bigsize_int(fd)
|
||||
if val is None:
|
||||
raise UnexpectedEndOfStream()
|
||||
return val
|
||||
@@ -183,7 +181,7 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str],
|
||||
elif field_type == 'varint':
|
||||
assert count == 1, count
|
||||
if isinstance(value, int):
|
||||
value = bigsize_from_int(value)
|
||||
value = write_bigsize_int(value)
|
||||
if not isinstance(value, (bytes, bytearray)):
|
||||
raise Exception(f"can only write bytes into fd. got: {value!r}")
|
||||
nbytes_written = fd.write(value)
|
||||
@@ -347,7 +345,8 @@ class LNSerializer:
|
||||
while _num_remaining_bytes_to_read(fd) > 0:
|
||||
tlv_record_type, tlv_record_val = _read_tlv_record(fd=fd)
|
||||
if not (tlv_record_type > last_seen_tlv_record_type):
|
||||
raise MalformedMsg("TLV records must be monotonically increasing by type")
|
||||
raise MsgInvalidFieldOrder(f"TLV records must be monotonically increasing by type. "
|
||||
f"cur: {tlv_record_type}. prev: {last_seen_tlv_record_type}")
|
||||
last_seen_tlv_record_type = tlv_record_type
|
||||
try:
|
||||
scheme = scheme_map[tlv_record_type]
|
||||
@@ -382,7 +381,7 @@ class LNSerializer:
|
||||
else:
|
||||
raise Exception(f"unexpected row in scheme: {row!r}")
|
||||
if _num_remaining_bytes_to_read(tlv_record_fd) > 0:
|
||||
raise MalformedMsg(f"TLV record ({tlv_stream_name}/{tlv_record_name}) has extra trailing garbage")
|
||||
raise MsgTrailingGarbage(f"TLV record ({tlv_stream_name}/{tlv_record_name}) has extra trailing garbage")
|
||||
return parsed
|
||||
|
||||
def encode_msg(self, msg_type: str, **kwargs) -> bytes:
|
||||
|
||||
Reference in New Issue
Block a user