1
0

lnmsg: implement tests from BOLT-01

This commit is contained in:
SomberNight
2020-03-15 04:56:58 +01:00
parent f353e6d55c
commit 85d7a13360
2 changed files with 205 additions and 27 deletions

View File

@@ -5,24 +5,13 @@ from typing import Callable, Tuple, Any, Dict, List, Sequence, Union, Optional
from collections import OrderedDict
class MalformedMsg(Exception):
pass
class UnknownMsgFieldType(MalformedMsg):
pass
class UnexpectedEndOfStream(MalformedMsg):
pass
class FieldEncodingNotMinimal(MalformedMsg):
pass
class UnknownMandatoryTLVRecordType(MalformedMsg):
pass
class MalformedMsg(Exception): pass
class UnknownMsgFieldType(MalformedMsg): pass
class UnexpectedEndOfStream(MalformedMsg): pass
class FieldEncodingNotMinimal(MalformedMsg): pass
class UnknownMandatoryTLVRecordType(MalformedMsg): pass
class MsgTrailingGarbage(MalformedMsg): pass
class MsgInvalidFieldOrder(MalformedMsg): pass
def _num_remaining_bytes_to_read(fd: io.BytesIO) -> int:
@@ -38,7 +27,7 @@ def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None:
raise UnexpectedEndOfStream(f"wants to read {n} bytes but only {nremaining} bytes left")
def bigsize_from_int(i: int) -> bytes:
def write_bigsize_int(i: int) -> bytes:
assert i >= 0, i
if i < 0xfd:
return int.to_bytes(i, length=1, byteorder="big", signed=False)
@@ -50,7 +39,7 @@ def bigsize_from_int(i: int) -> bytes:
return b"\xff" + int.to_bytes(i, length=8, byteorder="big", signed=False)
def read_int_from_bigsize(fd: io.BytesIO) -> Optional[int]:
def read_bigsize_int(fd: io.BytesIO) -> Optional[int]:
try:
first = fd.read(1)[0]
except IndexError:
@@ -59,13 +48,22 @@ def read_int_from_bigsize(fd: io.BytesIO) -> Optional[int]:
return first
elif first == 0xfd:
_assert_can_read_at_least_n_bytes(fd, 2)
return int.from_bytes(fd.read(2), byteorder="big", signed=False)
val = int.from_bytes(fd.read(2), byteorder="big", signed=False)
if not (0xfd <= val < 0x1_0000):
raise FieldEncodingNotMinimal()
return val
elif first == 0xfe:
_assert_can_read_at_least_n_bytes(fd, 4)
return int.from_bytes(fd.read(4), byteorder="big", signed=False)
val = int.from_bytes(fd.read(4), byteorder="big", signed=False)
if not (0x1_0000 <= val < 0x1_0000_0000):
raise FieldEncodingNotMinimal()
return val
elif first == 0xff:
_assert_can_read_at_least_n_bytes(fd, 8)
return int.from_bytes(fd.read(8), byteorder="big", signed=False)
val = int.from_bytes(fd.read(8), byteorder="big", signed=False)
if not (0x1_0000_0000 <= val):
raise FieldEncodingNotMinimal()
return val
raise Exception()
@@ -112,7 +110,7 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> U
return int.from_bytes(raw, byteorder="big", signed=False)
elif field_type == 'varint':
assert count == 1, count
val = read_int_from_bigsize(fd)
val = read_bigsize_int(fd)
if val is None:
raise UnexpectedEndOfStream()
return val
@@ -183,7 +181,7 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str],
elif field_type == 'varint':
assert count == 1, count
if isinstance(value, int):
value = bigsize_from_int(value)
value = write_bigsize_int(value)
if not isinstance(value, (bytes, bytearray)):
raise Exception(f"can only write bytes into fd. got: {value!r}")
nbytes_written = fd.write(value)
@@ -347,7 +345,8 @@ class LNSerializer:
while _num_remaining_bytes_to_read(fd) > 0:
tlv_record_type, tlv_record_val = _read_tlv_record(fd=fd)
if not (tlv_record_type > last_seen_tlv_record_type):
raise MalformedMsg("TLV records must be monotonically increasing by type")
raise MsgInvalidFieldOrder(f"TLV records must be monotonically increasing by type. "
f"cur: {tlv_record_type}. prev: {last_seen_tlv_record_type}")
last_seen_tlv_record_type = tlv_record_type
try:
scheme = scheme_map[tlv_record_type]
@@ -382,7 +381,7 @@ class LNSerializer:
else:
raise Exception(f"unexpected row in scheme: {row!r}")
if _num_remaining_bytes_to_read(tlv_record_fd) > 0:
raise MalformedMsg(f"TLV record ({tlv_stream_name}/{tlv_record_name}) has extra trailing garbage")
raise MsgTrailingGarbage(f"TLV record ({tlv_stream_name}/{tlv_record_name}) has extra trailing garbage")
return parsed
def encode_msg(self, msg_type: str, **kwargs) -> bytes: