lnmsg: rewrite LN msg encoding/decoding
This commit is contained in:
@@ -1,152 +1,225 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Callable, Tuple
|
||||
from collections import OrderedDict
|
||||
import csv
|
||||
import io
|
||||
from typing import Callable, Tuple, Any, Dict, List, Sequence, Union
|
||||
|
||||
def _eval_length_term(x, ma: dict) -> int:
|
||||
"""
|
||||
Evaluate a term of the simple language used
|
||||
to specify lightning message field lengths.
|
||||
|
||||
If `x` is an integer, it is returned as is,
|
||||
otherwise it is treated as a variable and
|
||||
looked up in `ma`.
|
||||
class MalformedMsg(Exception):
|
||||
pass
|
||||
|
||||
If the value in `ma` was no integer, it is
|
||||
assumed big-endian bytes and decoded.
|
||||
|
||||
Returns evaluated result as int
|
||||
"""
|
||||
try:
|
||||
x = int(x)
|
||||
except ValueError:
|
||||
x = ma[x]
|
||||
try:
|
||||
x = int(x)
|
||||
except ValueError:
|
||||
x = int.from_bytes(x, byteorder='big')
|
||||
return x
|
||||
class UnknownMsgFieldType(MalformedMsg):
|
||||
pass
|
||||
|
||||
def _eval_exp_with_ctx(exp, ctx: dict) -> int:
|
||||
"""
|
||||
Evaluate simple mathematical expression given
|
||||
in `exp` with context (variables assigned)
|
||||
from the dict `ctx`.
|
||||
|
||||
Returns evaluated result as int
|
||||
"""
|
||||
exp = str(exp)
|
||||
if "*" in exp:
|
||||
assert "+" not in exp
|
||||
result = 1
|
||||
for term in exp.split("*"):
|
||||
result *= _eval_length_term(term, ctx)
|
||||
return result
|
||||
return sum(_eval_length_term(x, ctx) for x in exp.split("+"))
|
||||
class UnexpectedEndOfStream(MalformedMsg):
|
||||
pass
|
||||
|
||||
def _make_handler(msg_name: str, v: dict) -> Callable[[bytes], Tuple[str, dict]]:
|
||||
"""
|
||||
Generate a message handler function (taking bytes)
|
||||
for message type `msg_name` with specification `v`
|
||||
|
||||
Check lib/lightning.json, `msg_name` could be 'init',
|
||||
and `v` could be
|
||||
def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None:
|
||||
cur_pos = fd.tell()
|
||||
end_pos = fd.seek(0, io.SEEK_END)
|
||||
fd.seek(cur_pos)
|
||||
if end_pos - cur_pos < n:
|
||||
raise UnexpectedEndOfStream(f"cur_pos={cur_pos}. end_pos={end_pos}. wants to read: {n}")
|
||||
|
||||
{ type: 16, payload: { 'gflen': ..., ... }, ... }
|
||||
|
||||
Returns function taking bytes
|
||||
"""
|
||||
def handler(data: bytes) -> Tuple[str, dict]:
|
||||
ma = {} # map of field name -> field data; after parsing msg
|
||||
pos = 0
|
||||
for fieldname in v["payload"]:
|
||||
poslenMap = v["payload"][fieldname]
|
||||
if "feature" in poslenMap and pos == len(data):
|
||||
continue
|
||||
#assert pos == _eval_exp_with_ctx(poslenMap["position"], ma) # this assert is expensive...
|
||||
length = poslenMap["length"]
|
||||
length = _eval_exp_with_ctx(length, ma)
|
||||
ma[fieldname] = data[pos:pos+length]
|
||||
pos += length
|
||||
# BOLT-01: "MUST ignore any additional data within a message beyond the length that it expects for that type."
|
||||
assert pos <= len(data), (msg_name, pos, len(data))
|
||||
return msg_name, ma
|
||||
return handler
|
||||
# TODO return int when it makes sense
|
||||
def _read_field(*, fd: io.BytesIO, field_type: str, count: int) -> bytes:
|
||||
if not fd: raise Exception()
|
||||
assert isinstance(count, int) and count >= 0, f"{count!r} must be non-neg int"
|
||||
if count == 0:
|
||||
return b""
|
||||
type_len = None
|
||||
if field_type == 'byte':
|
||||
type_len = 1
|
||||
elif field_type == 'u16':
|
||||
type_len = 2
|
||||
elif field_type == 'u32':
|
||||
type_len = 4
|
||||
elif field_type == 'u64':
|
||||
type_len = 8
|
||||
# TODO tu16/tu32/tu64
|
||||
elif field_type == 'chain_hash':
|
||||
type_len = 32
|
||||
elif field_type == 'channel_id':
|
||||
type_len = 32
|
||||
elif field_type == 'sha256':
|
||||
type_len = 32
|
||||
elif field_type == 'signature':
|
||||
type_len = 64
|
||||
elif field_type == 'point':
|
||||
type_len = 33
|
||||
elif field_type == 'short_channel_id':
|
||||
type_len = 8
|
||||
if type_len is None:
|
||||
raise UnknownMsgFieldType(f"unexpected field type: {field_type!r}")
|
||||
total_len = count * type_len
|
||||
_assert_can_read_at_least_n_bytes(fd, total_len)
|
||||
return fd.read(total_len)
|
||||
|
||||
|
||||
def _write_field(*, fd: io.BytesIO, field_type: str, count: int,
|
||||
value: Union[bytes, int]) -> None:
|
||||
if not fd: raise Exception()
|
||||
assert isinstance(count, int) and count >= 0, f"{count!r} must be non-neg int"
|
||||
if count == 0:
|
||||
return
|
||||
type_len = None
|
||||
if field_type == 'byte':
|
||||
type_len = 1
|
||||
elif field_type == 'u16':
|
||||
type_len = 2
|
||||
elif field_type == 'u32':
|
||||
type_len = 4
|
||||
elif field_type == 'u64':
|
||||
type_len = 8
|
||||
# TODO tu16/tu32/tu64
|
||||
elif field_type == 'chain_hash':
|
||||
type_len = 32
|
||||
elif field_type == 'channel_id':
|
||||
type_len = 32
|
||||
elif field_type == 'sha256':
|
||||
type_len = 32
|
||||
elif field_type == 'signature':
|
||||
type_len = 64
|
||||
elif field_type == 'point':
|
||||
type_len = 33
|
||||
elif field_type == 'short_channel_id':
|
||||
type_len = 8
|
||||
if type_len is None:
|
||||
raise UnknownMsgFieldType(f"unexpected fundamental type: {field_type!r}")
|
||||
total_len = count * type_len
|
||||
if isinstance(value, int) and (count == 1 or field_type == 'byte'):
|
||||
value = int.to_bytes(value, length=total_len, byteorder="big", signed=False)
|
||||
if not isinstance(value, (bytes, bytearray)):
|
||||
raise Exception(f"can only write bytes into fd. got: {value!r}")
|
||||
if total_len != len(value):
|
||||
raise Exception(f"unexpected field size. expected: {total_len}, got {len(value)}")
|
||||
nbytes_written = fd.write(value)
|
||||
if nbytes_written != len(value):
|
||||
raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
|
||||
|
||||
|
||||
class LNSerializer:
|
||||
def __init__(self):
|
||||
message_types = {}
|
||||
path = os.path.join(os.path.dirname(__file__), 'lightning.json')
|
||||
with open(path) as f:
|
||||
structured = json.loads(f.read(), object_pairs_hook=OrderedDict)
|
||||
self.msg_scheme_from_type = {} # type: Dict[bytes, List[Sequence[str]]]
|
||||
self.msg_type_from_name = {} # type: Dict[str, bytes]
|
||||
path = os.path.join(os.path.dirname(__file__), "lnwire", "peer_wire.csv")
|
||||
with open(path, newline='') as f:
|
||||
csvreader = csv.reader(f)
|
||||
for row in csvreader:
|
||||
#print(f">>> {row!r}")
|
||||
if row[0] == "msgtype":
|
||||
msg_type_name = row[1]
|
||||
msg_type_int = int(row[2])
|
||||
msg_type_bytes = msg_type_int.to_bytes(2, 'big')
|
||||
assert msg_type_bytes not in self.msg_scheme_from_type, f"type collision? for {msg_type_name}"
|
||||
assert msg_type_name not in self.msg_type_from_name, f"type collision? for {msg_type_name}"
|
||||
row[2] = msg_type_int
|
||||
self.msg_scheme_from_type[msg_type_bytes] = [tuple(row)]
|
||||
self.msg_type_from_name[msg_type_name] = msg_type_bytes
|
||||
elif row[0] == "msgdata":
|
||||
assert msg_type_name == row[1]
|
||||
self.msg_scheme_from_type[msg_type_bytes].append(tuple(row))
|
||||
else:
|
||||
pass # TODO
|
||||
|
||||
for msg_name in structured:
|
||||
v = structured[msg_name]
|
||||
# these message types are skipped since their types collide
|
||||
# (for example with pong, which also uses type=19)
|
||||
# we don't need them yet
|
||||
if msg_name in ["final_incorrect_cltv_expiry", "final_incorrect_htlc_amount"]:
|
||||
continue
|
||||
if len(v["payload"]) == 0:
|
||||
continue
|
||||
try:
|
||||
num = int(v["type"])
|
||||
except ValueError:
|
||||
#print("skipping", k)
|
||||
continue
|
||||
byts = num.to_bytes(2, 'big')
|
||||
assert byts not in message_types, (byts, message_types[byts].__name__, msg_name)
|
||||
names = [x.__name__ for x in message_types.values()]
|
||||
assert msg_name + "_handler" not in names, (msg_name, names)
|
||||
message_types[byts] = _make_handler(msg_name, v)
|
||||
message_types[byts].__name__ = msg_name + "_handler"
|
||||
|
||||
assert message_types[b"\x00\x10"].__name__ == "init_handler"
|
||||
self.structured = structured
|
||||
self.message_types = message_types
|
||||
|
||||
def encode_msg(self, msg_type : str, **kwargs) -> bytes:
|
||||
def encode_msg(self, msg_type: str, **kwargs) -> bytes:
|
||||
"""
|
||||
Encode kwargs into a Lightning message (bytes)
|
||||
of the type given in the msg_type string
|
||||
"""
|
||||
typ = self.structured[msg_type]
|
||||
data = int(typ["type"]).to_bytes(2, 'big')
|
||||
lengths = {}
|
||||
for k in typ["payload"]:
|
||||
poslenMap = typ["payload"][k]
|
||||
if k not in kwargs and "feature" in poslenMap:
|
||||
continue
|
||||
param = kwargs.get(k, 0)
|
||||
leng = _eval_exp_with_ctx(poslenMap["length"], lengths)
|
||||
try:
|
||||
clone = dict(lengths)
|
||||
clone.update(kwargs)
|
||||
leng = _eval_exp_with_ctx(poslenMap["length"], clone)
|
||||
except KeyError:
|
||||
pass
|
||||
try:
|
||||
if not isinstance(param, bytes):
|
||||
assert isinstance(param, int), "field {} is neither bytes or int".format(k)
|
||||
param = param.to_bytes(leng, 'big')
|
||||
except ValueError:
|
||||
raise Exception("{} does not fit in {} bytes".format(k, leng))
|
||||
lengths[k] = len(param)
|
||||
if lengths[k] != leng:
|
||||
raise Exception("field {} is {} bytes long, should be {} bytes long".format(k, lengths[k], leng))
|
||||
data += param
|
||||
return data
|
||||
#print(f">>> encode_msg. msg_type={msg_type}, payload={kwargs!r}")
|
||||
msg_type_bytes = self.msg_type_from_name[msg_type]
|
||||
scheme = self.msg_scheme_from_type[msg_type_bytes]
|
||||
with io.BytesIO() as fd:
|
||||
fd.write(msg_type_bytes)
|
||||
for row in scheme:
|
||||
if row[0] == "msgtype":
|
||||
pass
|
||||
elif row[0] == "msgdata":
|
||||
field_name = row[2]
|
||||
field_type = row[3]
|
||||
field_count_str = row[4]
|
||||
#print(f">>> encode_msg. msgdata. field_name={field_name!r}. field_type={field_type!r}. field_count_str={field_count_str!r}")
|
||||
if field_count_str == "":
|
||||
field_count = 1
|
||||
else:
|
||||
try:
|
||||
field_count = int(field_count_str)
|
||||
except ValueError:
|
||||
field_count = kwargs[field_count_str]
|
||||
if isinstance(field_count, (bytes, bytearray)):
|
||||
field_count = int.from_bytes(field_count, byteorder="big")
|
||||
assert isinstance(field_count, int)
|
||||
try:
|
||||
field_value = kwargs[field_name]
|
||||
except KeyError:
|
||||
if len(row) > 5:
|
||||
break # optional feature field not present
|
||||
else:
|
||||
field_value = 0 # default mandatory fields to zero
|
||||
#print(f">>> encode_msg. writing field: {field_name}. value={field_value!r}. field_type={field_type!r}. count={field_count!r}")
|
||||
try:
|
||||
_write_field(fd=fd,
|
||||
field_type=field_type,
|
||||
count=field_count,
|
||||
value=field_value)
|
||||
#print(f">>> encode_msg. so far: {fd.getvalue().hex()}")
|
||||
except UnknownMsgFieldType as e:
|
||||
pass # TODO
|
||||
else:
|
||||
pass # TODO
|
||||
return fd.getvalue()
|
||||
|
||||
def decode_msg(self, data : bytes) -> Tuple[str, dict]:
|
||||
def decode_msg(self, data: bytes) -> Tuple[str, dict]:
|
||||
"""
|
||||
Decode Lightning message by reading the first
|
||||
two bytes to determine message type.
|
||||
|
||||
Returns message type string and parsed message contents dict
|
||||
"""
|
||||
typ = data[:2]
|
||||
k, parsed = self.message_types[typ](data[2:])
|
||||
return k, parsed
|
||||
#print(f"decode_msg >>> {data.hex()}")
|
||||
assert len(data) >= 2
|
||||
msg_type_bytes = data[:2]
|
||||
msg_type_int = int.from_bytes(msg_type_bytes, byteorder="big", signed=False)
|
||||
scheme = self.msg_scheme_from_type[msg_type_bytes]
|
||||
assert scheme[0][2] == msg_type_int
|
||||
msg_type_name = scheme[0][1]
|
||||
parsed = {}
|
||||
with io.BytesIO(data[2:]) as fd:
|
||||
for row in scheme:
|
||||
#print(f"row: {row!r}")
|
||||
if row[0] == "msgtype":
|
||||
pass
|
||||
elif row[0] == "msgdata":
|
||||
field_name = row[2]
|
||||
field_type = row[3]
|
||||
field_count_str = row[4]
|
||||
if field_count_str == "":
|
||||
field_count = 1
|
||||
else:
|
||||
try:
|
||||
field_count = int(field_count_str)
|
||||
except ValueError:
|
||||
field_count = int.from_bytes(parsed[field_count_str], byteorder="big")
|
||||
#print(f">> count={field_count}. parsed={parsed}")
|
||||
try:
|
||||
parsed[field_name] = _read_field(fd=fd,
|
||||
field_type=field_type,
|
||||
count=field_count)
|
||||
except UnknownMsgFieldType as e:
|
||||
pass # TODO
|
||||
except UnexpectedEndOfStream as e:
|
||||
if len(row) > 5:
|
||||
break # optional feature field not present
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
pass # TODO
|
||||
return msg_type_name, parsed
|
||||
|
||||
|
||||
_inst = LNSerializer()
|
||||
encode_msg = _inst.encode_msg
|
||||
|
||||
Reference in New Issue
Block a user