1
0

lnonion: implement basis of varonion support

This commit is contained in:
SomberNight
2020-03-24 12:12:36 +01:00
parent 6ba08cc8d4
commit a66437f399
5 changed files with 326 additions and 107 deletions

View File

@@ -4,6 +4,8 @@ import io
from typing import Callable, Tuple, Any, Dict, List, Sequence, Union, Optional
from collections import OrderedDict
from .lnutil import OnionFailureCodeMetaFlag
class MalformedMsg(Exception): pass
class UnknownMsgFieldType(MalformedMsg): pass
@@ -254,8 +256,19 @@ def _resolve_field_count(field_count_str: str, *, vars_dict: dict, allow_any=Fal
return field_count
def _parse_msgtype_intvalue_for_onion_wire(value: str) -> int:
msg_type_int = 0
for component in value.split("|"):
try:
msg_type_int |= int(component)
except ValueError:
msg_type_int |= OnionFailureCodeMetaFlag[component]
return msg_type_int
class LNSerializer:
def __init__(self):
def __init__(self, *, for_onion_wire: bool = False):
# TODO msg_type could be 'int' everywhere...
self.msg_scheme_from_type = {} # type: Dict[bytes, List[Sequence[str]]]
self.msg_type_from_name = {} # type: Dict[str, bytes]
@@ -264,7 +277,10 @@ class LNSerializer:
self.in_tlv_stream_get_record_type_from_name = {} # type: Dict[str, Dict[str, int]]
self.in_tlv_stream_get_record_name_from_type = {} # type: Dict[str, Dict[int, str]]
path = os.path.join(os.path.dirname(__file__), "lnwire", "peer_wire.csv")
if for_onion_wire:
path = os.path.join(os.path.dirname(__file__), "lnwire", "onion_wire.csv")
else:
path = os.path.join(os.path.dirname(__file__), "lnwire", "peer_wire.csv")
with open(path, newline='') as f:
csvreader = csv.reader(f)
for row in csvreader:
@@ -272,7 +288,10 @@ class LNSerializer:
if row[0] == "msgtype":
# msgtype,<msgname>,<value>[,<option>]
msg_type_name = row[1]
msg_type_int = int(row[2])
if for_onion_wire:
msg_type_int = _parse_msgtype_intvalue_for_onion_wire(str(row[2]))
else:
msg_type_int = int(row[2])
msg_type_bytes = msg_type_int.to_bytes(2, 'big')
assert msg_type_bytes not in self.msg_scheme_from_type, f"type collision? for {msg_type_name}"
assert msg_type_name not in self.msg_type_from_name, f"type collision? for {msg_type_name}"
@@ -475,3 +494,6 @@ class LNSerializer:
_inst = LNSerializer()
encode_msg = _inst.encode_msg
decode_msg = _inst.decode_msg
OnionWireSerializer = LNSerializer(for_onion_wire=True)