move force_close_channel to lnbase, test it, add FORCE_CLOSING state
This commit is contained in:
@@ -19,7 +19,7 @@ import aiorpcx
|
|||||||
from .crypto import sha256, sha256d
|
from .crypto import sha256, sha256d
|
||||||
from . import bitcoin
|
from . import bitcoin
|
||||||
from . import ecc
|
from . import ecc
|
||||||
from .ecc import sig_string_from_r_and_s, get_r_and_s_from_sig_string
|
from .ecc import sig_string_from_r_and_s, get_r_and_s_from_sig_string, der_sig_from_sig_string
|
||||||
from . import constants
|
from . import constants
|
||||||
from .util import PrintError, bh2u, print_error, bfh, log_exceptions, list_enabled_bits, ignore_exceptions
|
from .util import PrintError, bh2u, print_error, bfh, log_exceptions, list_enabled_bits, ignore_exceptions
|
||||||
from .transaction import Transaction, TxOutput
|
from .transaction import Transaction, TxOutput
|
||||||
@@ -1158,6 +1158,25 @@ class Peer(PrintError):
|
|||||||
self.print_error('Channel closed', txid)
|
self.print_error('Channel closed', txid)
|
||||||
return txid
|
return txid
|
||||||
|
|
||||||
|
async def force_close_channel(self, chan_id):
|
||||||
|
chan = self.channels[chan_id]
|
||||||
|
# local_commitment always gives back the next expected local_commitment,
|
||||||
|
# but in this case, we want the current one. So substract one ctn number
|
||||||
|
old_local_state = chan.config[LOCAL]
|
||||||
|
chan.config[LOCAL]=chan.config[LOCAL]._replace(ctn=chan.config[LOCAL].ctn - 1)
|
||||||
|
tx = chan.pending_local_commitment
|
||||||
|
chan.config[LOCAL] = old_local_state
|
||||||
|
tx.sign({bh2u(chan.config[LOCAL].multisig_key.pubkey): (chan.config[LOCAL].multisig_key.privkey, True)})
|
||||||
|
remote_sig = chan.config[LOCAL].current_commitment_signature
|
||||||
|
remote_sig = der_sig_from_sig_string(remote_sig) + b"\x01"
|
||||||
|
none_idx = tx._inputs[0]["signatures"].index(None)
|
||||||
|
tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig))
|
||||||
|
assert tx.is_complete()
|
||||||
|
# TODO persist FORCE_CLOSING state to disk
|
||||||
|
chan.set_state('FORCE_CLOSING')
|
||||||
|
self.lnworker.save_channel(chan)
|
||||||
|
return await self.network.broadcast_transaction(tx)
|
||||||
|
|
||||||
@log_exceptions
|
@log_exceptions
|
||||||
async def on_shutdown(self, payload):
|
async def on_shutdown(self, payload):
|
||||||
# length of scripts allowed in BOLT-02
|
# length of scripts allowed in BOLT-02
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING
|
|||||||
import threading
|
import threading
|
||||||
import socket
|
import socket
|
||||||
import json
|
import json
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
import dns.resolver
|
import dns.resolver
|
||||||
import dns.exception
|
import dns.exception
|
||||||
@@ -267,18 +268,13 @@ class LNWorker(PrintError):
|
|||||||
return addr, peer, fut
|
return addr, peer, fut
|
||||||
|
|
||||||
def _pay(self, invoice, amount_sat=None):
|
def _pay(self, invoice, amount_sat=None):
|
||||||
addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP)
|
addr = self._check_invoice(invoice, amount_sat)
|
||||||
payment_hash = addr.paymenthash
|
route = self._create_route_from_invoice(decoded_invoice=addr)
|
||||||
amount_sat = (addr.amount * COIN) if addr.amount else amount_sat
|
peer = self.peers[route[0].node_id]
|
||||||
if amount_sat is None:
|
return addr, peer, self._pay_to_route(route, addr)
|
||||||
raise InvoiceError(_("Missing amount"))
|
|
||||||
amount_msat = int(amount_sat * 1000)
|
async def _pay_to_route(self, route, addr):
|
||||||
if addr.get_min_final_cltv_expiry() > 60 * 144:
|
short_channel_id = route[0].short_channel_id
|
||||||
raise InvoiceError("{}\n{}".format(
|
|
||||||
_("Invoice wants us to risk locking funds for unreasonably long."),
|
|
||||||
f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
|
|
||||||
route = self._create_route_from_invoice(decoded_invoice=addr, amount_msat=amount_msat)
|
|
||||||
node_id, short_channel_id = route[0].node_id, route[0].short_channel_id
|
|
||||||
with self.lock:
|
with self.lock:
|
||||||
channels = list(self.channels.values())
|
channels = list(self.channels.values())
|
||||||
for chan in channels:
|
for chan in channels:
|
||||||
@@ -286,11 +282,24 @@ class LNWorker(PrintError):
|
|||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raise Exception("PathFinder returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id)))
|
raise Exception("PathFinder returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id)))
|
||||||
peer = self.peers[node_id]
|
peer = self.peers[route[0].node_id]
|
||||||
coro = peer.pay(route, chan, amount_msat, payment_hash, addr.get_min_final_cltv_expiry())
|
return await peer.pay(route, chan, int(addr.amount * COIN * 1000), addr.paymenthash, addr.get_min_final_cltv_expiry())
|
||||||
return addr, peer, coro
|
|
||||||
|
|
||||||
def _create_route_from_invoice(self, decoded_invoice, amount_msat) -> List[RouteEdge]:
|
@staticmethod
|
||||||
|
def _check_invoice(invoice, amount_sat=None):
|
||||||
|
addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP)
|
||||||
|
if amount_sat:
|
||||||
|
addr.amount = Decimal(amount_sat) / COIN
|
||||||
|
if addr.amount is None:
|
||||||
|
raise InvoiceError(_("Missing amount"))
|
||||||
|
if addr.get_min_final_cltv_expiry() > 60 * 144:
|
||||||
|
raise InvoiceError("{}\n{}".format(
|
||||||
|
_("Invoice wants us to risk locking funds for unreasonably long."),
|
||||||
|
f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
|
||||||
|
return addr
|
||||||
|
|
||||||
|
def _create_route_from_invoice(self, decoded_invoice) -> List[RouteEdge]:
|
||||||
|
amount_msat = int(decoded_invoice.amount * COIN * 1000)
|
||||||
invoice_pubkey = decoded_invoice.pubkey.serialize()
|
invoice_pubkey = decoded_invoice.pubkey.serialize()
|
||||||
# use 'r' field from invoice
|
# use 'r' field from invoice
|
||||||
route = None # type: List[RouteEdge]
|
route = None # type: List[RouteEdge]
|
||||||
@@ -441,19 +450,8 @@ class LNWorker(PrintError):
|
|||||||
|
|
||||||
async def force_close_channel(self, chan_id):
|
async def force_close_channel(self, chan_id):
|
||||||
chan = self.channels[chan_id]
|
chan = self.channels[chan_id]
|
||||||
# local_commitment always gives back the next expected local_commitment,
|
peer = self.peers[chan.node_id]
|
||||||
# but in this case, we want the current one. So substract one ctn number
|
return await peer.force_close_channel(chan_id)
|
||||||
old_local_state = chan.config[LOCAL]
|
|
||||||
chan.config[LOCAL]=chan.config[LOCAL]._replace(ctn=chan.config[LOCAL].ctn - 1)
|
|
||||||
tx = chan.pending_local_commitment
|
|
||||||
chan.config[LOCAL] = old_local_state
|
|
||||||
tx.sign({bh2u(chan.config[LOCAL].multisig_key.pubkey): (chan.config[LOCAL].multisig_key.privkey, True)})
|
|
||||||
remote_sig = chan.config[LOCAL].current_commitment_signature
|
|
||||||
remote_sig = der_sig_from_sig_string(remote_sig) + b"\x01"
|
|
||||||
none_idx = tx._inputs[0]["signatures"].index(None)
|
|
||||||
tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig))
|
|
||||||
assert tx.is_complete()
|
|
||||||
return await self.network.broadcast_transaction(tx)
|
|
||||||
|
|
||||||
def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
|
def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from electrum.util import bh2u
|
|||||||
from electrum.lnbase import Peer, decode_msg, gen_msg
|
from electrum.lnbase import Peer, decode_msg, gen_msg
|
||||||
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
|
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
|
||||||
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
|
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
|
||||||
|
from electrum.lnutil import PaymentFailure
|
||||||
from electrum.lnrouter import ChannelDB, LNPathFinder
|
from electrum.lnrouter import ChannelDB, LNPathFinder
|
||||||
from electrum.lnworker import LNWorker
|
from electrum.lnworker import LNWorker
|
||||||
|
|
||||||
@@ -33,7 +34,7 @@ def noop_lock():
|
|||||||
yield
|
yield
|
||||||
|
|
||||||
class MockNetwork:
|
class MockNetwork:
|
||||||
def __init__(self):
|
def __init__(self, tx_queue):
|
||||||
self.callbacks = defaultdict(list)
|
self.callbacks = defaultdict(list)
|
||||||
self.lnwatcher = None
|
self.lnwatcher = None
|
||||||
user_config = {}
|
user_config = {}
|
||||||
@@ -43,6 +44,7 @@ class MockNetwork:
|
|||||||
self.channel_db = ChannelDB(self)
|
self.channel_db = ChannelDB(self)
|
||||||
self.interface = None
|
self.interface = None
|
||||||
self.path_finder = LNPathFinder(self.channel_db)
|
self.path_finder = LNPathFinder(self.channel_db)
|
||||||
|
self.tx_queue = tx_queue
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def callback_lock(self):
|
def callback_lock(self):
|
||||||
@@ -55,12 +57,16 @@ class MockNetwork:
|
|||||||
def get_local_height(self):
|
def get_local_height(self):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
async def broadcast_transaction(self, tx):
|
||||||
|
if self.tx_queue:
|
||||||
|
await self.tx_queue.put(tx)
|
||||||
|
|
||||||
class MockLNWorker:
|
class MockLNWorker:
|
||||||
def __init__(self, remote_keypair, local_keypair, chan):
|
def __init__(self, remote_keypair, local_keypair, chan, tx_queue):
|
||||||
self.chan = chan
|
self.chan = chan
|
||||||
self.remote_keypair = remote_keypair
|
self.remote_keypair = remote_keypair
|
||||||
self.node_keypair = local_keypair
|
self.node_keypair = local_keypair
|
||||||
self.network = MockNetwork()
|
self.network = MockNetwork(tx_queue)
|
||||||
self.channels = {self.chan.channel_id: self.chan}
|
self.channels = {self.chan.channel_id: self.chan}
|
||||||
self.invoices = {}
|
self.invoices = {}
|
||||||
|
|
||||||
@@ -76,10 +82,12 @@ class MockLNWorker:
|
|||||||
return self.channels
|
return self.channels
|
||||||
|
|
||||||
def save_channel(self, chan):
|
def save_channel(self, chan):
|
||||||
pass
|
print("Ignoring channel save")
|
||||||
|
|
||||||
get_invoice = LNWorker.get_invoice
|
get_invoice = LNWorker.get_invoice
|
||||||
_create_route_from_invoice = LNWorker._create_route_from_invoice
|
_create_route_from_invoice = LNWorker._create_route_from_invoice
|
||||||
|
_check_invoice = staticmethod(LNWorker._check_invoice)
|
||||||
|
_pay_to_route = LNWorker._pay_to_route
|
||||||
|
|
||||||
class MockTransport:
|
class MockTransport:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -120,18 +128,19 @@ class TestPeer(unittest.TestCase):
|
|||||||
self.alice_channel, self.bob_channel = create_test_channels()
|
self.alice_channel, self.bob_channel = create_test_channels()
|
||||||
|
|
||||||
def test_require_data_loss_protect(self):
|
def test_require_data_loss_protect(self):
|
||||||
mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel)
|
mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel, tx_queue=None)
|
||||||
mock_transport = NoFeaturesTransport()
|
mock_transport = NoFeaturesTransport()
|
||||||
p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 33), request_initial_sync=False, transport=mock_transport)
|
p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 33), request_initial_sync=False, transport=mock_transport)
|
||||||
mock_lnworker.peer = p1
|
mock_lnworker.peer = p1
|
||||||
with self.assertRaises(LightningPeerConnectionClosed):
|
with self.assertRaises(LightningPeerConnectionClosed):
|
||||||
asyncio.get_event_loop().run_until_complete(asyncio.wait_for(p1._main_loop(), 1))
|
run(asyncio.wait_for(p1._main_loop(), 1))
|
||||||
|
|
||||||
def test_payment(self):
|
def prepare_peers(self):
|
||||||
k1, k2 = keypair(), keypair()
|
k1, k2 = keypair(), keypair()
|
||||||
t1, t2 = transport_pair()
|
t1, t2 = transport_pair()
|
||||||
w1 = MockLNWorker(k1, k2, self.alice_channel)
|
q1, q2 = asyncio.Queue(), asyncio.Queue()
|
||||||
w2 = MockLNWorker(k2, k1, self.bob_channel)
|
w1 = MockLNWorker(k1, k2, self.alice_channel, tx_queue=q1)
|
||||||
|
w2 = MockLNWorker(k2, k1, self.bob_channel, tx_queue=q2)
|
||||||
p1 = Peer(w1, LNPeerAddr("bogus1", 1337, k1.pubkey),
|
p1 = Peer(w1, LNPeerAddr("bogus1", 1337, k1.pubkey),
|
||||||
request_initial_sync=False, transport=t1)
|
request_initial_sync=False, transport=t1)
|
||||||
p2 = Peer(w2, LNPeerAddr("bogus2", 1337, k2.pubkey),
|
p2 = Peer(w2, LNPeerAddr("bogus2", 1337, k2.pubkey),
|
||||||
@@ -145,6 +154,11 @@ class TestPeer(unittest.TestCase):
|
|||||||
# this populates the channel graph:
|
# this populates the channel graph:
|
||||||
p1.mark_open(self.alice_channel)
|
p1.mark_open(self.alice_channel)
|
||||||
p2.mark_open(self.bob_channel)
|
p2.mark_open(self.bob_channel)
|
||||||
|
return p1, p2, w1, w2, q1, q2
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def prepare_invoice(w2 # receiver
|
||||||
|
):
|
||||||
amount_btc = 100000/Decimal(COIN)
|
amount_btc = 100000/Decimal(COIN)
|
||||||
payment_preimage = os.urandom(32)
|
payment_preimage = os.urandom(32)
|
||||||
RHASH = sha256(payment_preimage)
|
RHASH = sha256(payment_preimage)
|
||||||
@@ -156,13 +170,23 @@ class TestPeer(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
pay_req = lnencode(addr, w2.node_keypair.privkey)
|
pay_req = lnencode(addr, w2.node_keypair.privkey)
|
||||||
w2.invoices[bh2u(RHASH)] = (bh2u(payment_preimage), pay_req)
|
w2.invoices[bh2u(RHASH)] = (bh2u(payment_preimage), pay_req)
|
||||||
l = asyncio.get_event_loop()
|
return pay_req
|
||||||
async def pay():
|
|
||||||
fut = asyncio.Future()
|
|
||||||
def evt_set(event, _lnworker, msg):
|
|
||||||
fut.set_result(msg)
|
|
||||||
w2.network.register_callback(evt_set, ['ln_message'])
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def prepare_ln_message_future(w2 # receiver
|
||||||
|
):
|
||||||
|
fut = asyncio.Future()
|
||||||
|
def evt_set(event, _lnworker, msg):
|
||||||
|
fut.set_result(msg)
|
||||||
|
w2.network.register_callback(evt_set, ['ln_message'])
|
||||||
|
return fut
|
||||||
|
|
||||||
|
def test_payment(self):
|
||||||
|
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers()
|
||||||
|
pay_req = self.prepare_invoice(w2)
|
||||||
|
fut = self.prepare_ln_message_future(w2)
|
||||||
|
|
||||||
|
async def pay():
|
||||||
addr, peer, coro = LNWorker._pay(w1, pay_req)
|
addr, peer, coro = LNWorker._pay(w1, pay_req)
|
||||||
await coro
|
await coro
|
||||||
print("HTLC ADDED")
|
print("HTLC ADDED")
|
||||||
@@ -170,4 +194,28 @@ class TestPeer(unittest.TestCase):
|
|||||||
gath.cancel()
|
gath.cancel()
|
||||||
gath = asyncio.gather(pay(), p1._main_loop(), p2._main_loop())
|
gath = asyncio.gather(pay(), p1._main_loop(), p2._main_loop())
|
||||||
with self.assertRaises(asyncio.CancelledError):
|
with self.assertRaises(asyncio.CancelledError):
|
||||||
l.run_until_complete(gath)
|
run(gath)
|
||||||
|
|
||||||
|
def test_channel_usage_after_closing(self):
|
||||||
|
p1, p2, w1, w2, q1, q2 = self.prepare_peers()
|
||||||
|
pay_req = self.prepare_invoice(w2)
|
||||||
|
|
||||||
|
addr = w1._check_invoice(pay_req)
|
||||||
|
route = w1._create_route_from_invoice(decoded_invoice=addr)
|
||||||
|
|
||||||
|
run(p1.force_close_channel(self.alice_channel.channel_id))
|
||||||
|
# check if a tx (commitment transaction) was broadcasted:
|
||||||
|
assert q1.qsize() == 1
|
||||||
|
|
||||||
|
with self.assertRaises(PaymentFailure) as e:
|
||||||
|
w1._create_route_from_invoice(decoded_invoice=addr)
|
||||||
|
self.assertEqual(str(e.exception), 'No path found')
|
||||||
|
|
||||||
|
peer = w1.peers[route[0].node_id]
|
||||||
|
# AssertionError is ok since we shouldn't use old routes, and the
|
||||||
|
# route finding should fail when channel is closed
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
run(asyncio.gather(w1._pay_to_route(route, addr), p1._main_loop(), p2._main_loop()))
|
||||||
|
|
||||||
|
def run(coro):
|
||||||
|
asyncio.get_event_loop().run_until_complete(coro)
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from electrum import lnchan
|
|||||||
from electrum import lnutil
|
from electrum import lnutil
|
||||||
from electrum import bip32 as bip32_utils
|
from electrum import bip32 as bip32_utils
|
||||||
from electrum.lnutil import SENT, LOCAL, REMOTE, RECEIVED
|
from electrum.lnutil import SENT, LOCAL, REMOTE, RECEIVED
|
||||||
|
from electrum.ecc import sig_string_from_der_sig
|
||||||
|
|
||||||
one_bitcoin_in_msat = bitcoin.COIN * 1000
|
one_bitcoin_in_msat = bitcoin.COIN * 1000
|
||||||
|
|
||||||
@@ -81,7 +82,8 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate
|
|||||||
per_commitment_secret_seed=seed,
|
per_commitment_secret_seed=seed,
|
||||||
funding_locked_received=True,
|
funding_locked_received=True,
|
||||||
was_announced=False,
|
was_announced=False,
|
||||||
current_commitment_signature=None,
|
# just a random signature
|
||||||
|
current_commitment_signature=sig_string_from_der_sig(bytes.fromhex('3046022100c66e112e22b91b96b795a6dd5f4b004f3acccd9a2a31bf104840f256855b7aa3022100e711b868b62d87c7edd95a2370e496b9cb6a38aff13c9f64f9ff2f3b2a0052dd')),
|
||||||
current_htlc_signatures=None,
|
current_htlc_signatures=None,
|
||||||
),
|
),
|
||||||
"constraints":lnbase.ChannelConstraints(
|
"constraints":lnbase.ChannelConstraints(
|
||||||
@@ -185,6 +187,14 @@ class TestChannel(unittest.TestCase):
|
|||||||
|
|
||||||
self.htlc = self.bob_channel.log[lnutil.REMOTE].adds[0]
|
self.htlc = self.bob_channel.log[lnutil.REMOTE].adds[0]
|
||||||
|
|
||||||
|
def test_concurrent_reversed_payment(self):
|
||||||
|
self.htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x02')
|
||||||
|
self.htlc_dict['amount_msat'] += 1000
|
||||||
|
bob_idx = self.bob_channel.add_htlc(self.htlc_dict)
|
||||||
|
alice_idx = self.alice_channel.receive_htlc(self.htlc_dict)
|
||||||
|
self.alice_channel.receive_new_commitment(*self.bob_channel.sign_next_commitment())
|
||||||
|
self.assertEquals(len(self.alice_channel.pending_remote_commitment.outputs()), 3)
|
||||||
|
|
||||||
def test_SimpleAddSettleWorkflow(self):
|
def test_SimpleAddSettleWorkflow(self):
|
||||||
alice_channel, bob_channel = self.alice_channel, self.bob_channel
|
alice_channel, bob_channel = self.alice_channel, self.bob_channel
|
||||||
htlc = self.htlc
|
htlc = self.htlc
|
||||||
|
|||||||
Reference in New Issue
Block a user