1
0

move connection string decoding to lnworker, fix test_lnutil

This commit is contained in:
Janus
2018-09-27 16:43:33 +02:00
committed by ThomasV
parent 24cf4e7eb0
commit efc8d50570
5 changed files with 133 additions and 71 deletions

View File

@@ -5,8 +5,7 @@ from PyQt5.QtWidgets import *
from electrum.util import inv_dict, bh2u, bfh
from electrum.i18n import _
from electrum.lnhtlc import HTLCStateMachine
from electrum.lnaddr import lndecode
from electrum.lnutil import LOCAL, REMOTE
from electrum.lnutil import LOCAL, REMOTE, ConnStringFormatError
from .util import MyTreeWidget, SortableTreeWidgetItem, WindowModalDialog, Buttons, OkButton, CancelButton
from .amountedit import BTCAmountEdit
@@ -108,55 +107,12 @@ class ChannelsList(MyTreeWidget):
return
local_amt = local_amt_inp.get_amount()
push_amt = push_amt_inp.get_amount()
connect_contents = str(remote_nodeid.text())
nodeid_hex, rest = self.parse_connect_contents(connect_contents)
connect_contents = str(remote_nodeid.text()).strip()
try:
node_id = bfh(nodeid_hex)
assert len(node_id) == 33
except:
self.parent.show_error(_('Invalid node ID, must be 33 bytes and hexadecimal'))
return
peer = lnworker.peers.get(node_id)
if not peer:
all_nodes = self.parent.network.channel_db.nodes
node_info = all_nodes.get(node_id, None)
if rest is not None:
try:
host, port = rest.split(":")
except ValueError:
self.parent.show_error(_('Connection strings must be in <node_pubkey>@<host>:<port> format'))
return
elif node_info:
host, port = node_info.addresses[0]
else:
self.parent.show_error(_('Unknown node:') + ' ' + nodeid_hex)
return
try:
int(port)
except:
self.parent.show_error(_('Port number must be decimal'))
return
lnworker.add_peer(host, port, node_id)
self.main_window.protect(self.open_channel, (node_id, local_amt, push_amt))
@classmethod
def parse_connect_contents(cls, connect_contents: str):
rest = None
try:
# connection string?
nodeid_hex, rest = connect_contents.split("@")
except ValueError:
try:
# invoice?
invoice = lndecode(connect_contents)
nodeid_bytes = invoice.pubkey.serialize()
nodeid_hex = bh2u(nodeid_bytes)
except:
# node id as hex?
nodeid_hex = connect_contents
return nodeid_hex, rest
self.main_window.protect(self.open_channel, (connect_contents, local_amt, push_amt))
except ConnStringFormatError as e:
self.parent.show_error(str(e))
def open_channel(self, *args, **kwargs):
self.parent.wallet.lnworker.open_channel(*args, **kwargs)