1
0

ln: add lightning_listen config option

This commit is contained in:
Janus
2018-10-16 17:45:28 +02:00
committed by ThomasV
parent 52377dbfa0
commit 962f70c7da
4 changed files with 170 additions and 39 deletions

View File

@@ -23,7 +23,6 @@ class HandshakeState(object):
self.h = sha256(self.h + data)
return self.h
def get_nonce_bytes(n):
"""BOLT 8 requires the nonce to be 12 bytes, 4 bytes leading
zeroes and 8 bytes little endian encoded 64 bit integer.
@@ -60,29 +59,22 @@ def get_bolt8_hkdf(salt, ikm):
return T1, T2
def act1_initiator_message(hs, epriv, epub):
hs.update(epub)
ss = get_ecdh(epriv, hs.responder_pub)
ck2, temp_k1 = get_bolt8_hkdf(hs.ck, ss)
hs.ck = ck2
c = aead_encrypt(temp_k1, 0, hs.h, b"")
c = aead_encrypt(temp_k1, 0, hs.update(epub), b"")
#for next step if we do it
hs.update(c)
msg = hs.handshake_version + epub + c
assert len(msg) == 50
return msg
return msg, temp_k1
def create_ephemeral_key() -> (bytes, bytes):
privkey = ecc.ECPrivkey.generate_random_key()
return privkey.get_secret_bytes(), privkey.get_public_key_bytes()
class LNTransport:
def __init__(self, privkey, remote_pubkey, reader, writer):
self.privkey = privkey
self.remote_pubkey = remote_pubkey
self.reader = reader
self.writer = writer
class LNTransportBase:
def send_bytes(self, msg):
l = len(msg).to_bytes(2, 'big')
lc = aead_encrypt(self.sk, self.sn(), b'', l)
@@ -116,12 +108,97 @@ class LNTransport:
raise LightningPeerConnectionClosed()
read_buffer += s
def rn(self):
o = self._rn, self.rk
self._rn += 1
if self._rn == 1000:
self.r_ck, self.rk = get_bolt8_hkdf(self.r_ck, self.rk)
self._rn = 0
return o
def sn(self):
o = self._sn
self._sn += 1
if self._sn == 1000:
self.s_ck, self.sk = get_bolt8_hkdf(self.s_ck, self.sk)
self._sn = 0
return o
def init_counters(self, ck):
# init counters
self._sn = 0
self._rn = 0
self.r_ck = ck
self.s_ck = ck
class LNResponderTransport(LNTransportBase):
def __init__(self, privkey, reader, writer):
self.privkey = privkey
self.reader = reader
self.writer = writer
async def handshake(self, **kwargs):
hs = HandshakeState(privkey_to_pubkey(self.privkey))
act1 = b''
while len(act1) < 50:
act1 += await self.reader.read(50 - len(act1))
if len(act1) != 50:
raise HandshakeFailed('responder: short act 1 read, length is ' + str(len(act1)))
if bytes([act1[0]]) != HandshakeState.handshake_version:
raise HandshakeFailed('responder: bad handshake version in act 1')
c = act1[-16:]
re = act1[1:34]
h = hs.update(re)
ss = get_ecdh(self.privkey, re)
ck, temp_k1 = get_bolt8_hkdf(sha256(HandshakeState.protocol_name), ss)
_p = aead_decrypt(temp_k1, 0, h, c)
hs.update(c)
# act 2
if 'epriv' not in kwargs:
epriv, epub = create_ephemeral_key()
else:
epriv = kwargs['epriv']
epub = ecc.ECPrivkey(epriv).get_public_key_bytes()
hs.ck = ck
hs.responder_pub = re
msg, temp_k2 = act1_initiator_message(hs, epriv, epub)
self.writer.write(msg)
# act 3
act3 = b''
while len(act3) < 66:
act3 += await self.reader.read(66 - len(act3))
if len(act3) != 66:
raise HandshakeFailed('responder: short act 3 read, length is ' + str(len(act3)))
if bytes([act3[0]]) != HandshakeState.handshake_version:
raise HandshakeFailed('responder: bad handshake version in act 3')
c = act3[1:50]
t = act3[-16:]
rs = aead_decrypt(temp_k2, 1, hs.h, c)
ss = get_ecdh(epriv, rs)
ck, temp_k3 = get_bolt8_hkdf(hs.ck, ss)
_p = aead_decrypt(temp_k3, 0, hs.update(c), t)
self.rk, self.sk = get_bolt8_hkdf(ck, b'')
self.init_counters(ck)
return rs
class LNTransport(LNTransportBase):
def __init__(self, privkey, remote_pubkey, reader, writer):
assert type(privkey) is bytes and len(privkey) == 32
self.privkey = privkey
self.remote_pubkey = remote_pubkey
self.reader = reader
self.writer = writer
async def handshake(self):
hs = HandshakeState(self.remote_pubkey)
# Get a new ephemeral key
epriv, epub = create_ephemeral_key()
msg = act1_initiator_message(hs, epriv, epub)
msg, _temp_k1 = act1_initiator_message(hs, epriv, epub)
# act 1
self.writer.write(msg)
rspns = await self.reader.read(2**10)
@@ -145,27 +222,7 @@ class LNTransport:
ck, temp_k3 = get_bolt8_hkdf(hs.ck, ss)
hs.ck = ck
t = aead_encrypt(temp_k3, 0, hs.h, b'')
self.sk, self.rk = get_bolt8_hkdf(hs.ck, b'')
msg = hs.handshake_version + c + t
self.writer.write(msg)
# init counters
self._sn = 0
self._rn = 0
self.r_ck = ck
self.s_ck = ck
def rn(self):
o = self._rn, self.rk
self._rn += 1
if self._rn == 1000:
self.r_ck, self.rk = get_bolt8_hkdf(self.r_ck, self.rk)
self._rn = 0
return o
def sn(self):
o = self._sn
self._sn += 1
if self._sn == 1000:
self.s_ck, self.sk = get_bolt8_hkdf(self.s_ck, self.sk)
self._sn = 0
return o
self.sk, self.rk = get_bolt8_hkdf(hs.ck, b'')
self.init_counters(ck)