move force_close_channel to lnbase, test it, add FORCE_CLOSING state
This commit is contained in:
@@ -19,7 +19,7 @@ import aiorpcx
|
||||
from .crypto import sha256, sha256d
|
||||
from . import bitcoin
|
||||
from . import ecc
|
||||
from .ecc import sig_string_from_r_and_s, get_r_and_s_from_sig_string
|
||||
from .ecc import sig_string_from_r_and_s, get_r_and_s_from_sig_string, der_sig_from_sig_string
|
||||
from . import constants
|
||||
from .util import PrintError, bh2u, print_error, bfh, log_exceptions, list_enabled_bits, ignore_exceptions
|
||||
from .transaction import Transaction, TxOutput
|
||||
@@ -1158,6 +1158,25 @@ class Peer(PrintError):
|
||||
self.print_error('Channel closed', txid)
|
||||
return txid
|
||||
|
||||
async def force_close_channel(self, chan_id):
|
||||
chan = self.channels[chan_id]
|
||||
# local_commitment always gives back the next expected local_commitment,
|
||||
# but in this case, we want the current one. So substract one ctn number
|
||||
old_local_state = chan.config[LOCAL]
|
||||
chan.config[LOCAL]=chan.config[LOCAL]._replace(ctn=chan.config[LOCAL].ctn - 1)
|
||||
tx = chan.pending_local_commitment
|
||||
chan.config[LOCAL] = old_local_state
|
||||
tx.sign({bh2u(chan.config[LOCAL].multisig_key.pubkey): (chan.config[LOCAL].multisig_key.privkey, True)})
|
||||
remote_sig = chan.config[LOCAL].current_commitment_signature
|
||||
remote_sig = der_sig_from_sig_string(remote_sig) + b"\x01"
|
||||
none_idx = tx._inputs[0]["signatures"].index(None)
|
||||
tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig))
|
||||
assert tx.is_complete()
|
||||
# TODO persist FORCE_CLOSING state to disk
|
||||
chan.set_state('FORCE_CLOSING')
|
||||
self.lnworker.save_channel(chan)
|
||||
return await self.network.broadcast_transaction(tx)
|
||||
|
||||
@log_exceptions
|
||||
async def on_shutdown(self, payload):
|
||||
# length of scripts allowed in BOLT-02
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING
|
||||
import threading
|
||||
import socket
|
||||
import json
|
||||
from decimal import Decimal
|
||||
|
||||
import dns.resolver
|
||||
import dns.exception
|
||||
@@ -267,18 +268,13 @@ class LNWorker(PrintError):
|
||||
return addr, peer, fut
|
||||
|
||||
def _pay(self, invoice, amount_sat=None):
|
||||
addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP)
|
||||
payment_hash = addr.paymenthash
|
||||
amount_sat = (addr.amount * COIN) if addr.amount else amount_sat
|
||||
if amount_sat is None:
|
||||
raise InvoiceError(_("Missing amount"))
|
||||
amount_msat = int(amount_sat * 1000)
|
||||
if addr.get_min_final_cltv_expiry() > 60 * 144:
|
||||
raise InvoiceError("{}\n{}".format(
|
||||
_("Invoice wants us to risk locking funds for unreasonably long."),
|
||||
f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
|
||||
route = self._create_route_from_invoice(decoded_invoice=addr, amount_msat=amount_msat)
|
||||
node_id, short_channel_id = route[0].node_id, route[0].short_channel_id
|
||||
addr = self._check_invoice(invoice, amount_sat)
|
||||
route = self._create_route_from_invoice(decoded_invoice=addr)
|
||||
peer = self.peers[route[0].node_id]
|
||||
return addr, peer, self._pay_to_route(route, addr)
|
||||
|
||||
async def _pay_to_route(self, route, addr):
|
||||
short_channel_id = route[0].short_channel_id
|
||||
with self.lock:
|
||||
channels = list(self.channels.values())
|
||||
for chan in channels:
|
||||
@@ -286,11 +282,24 @@ class LNWorker(PrintError):
|
||||
break
|
||||
else:
|
||||
raise Exception("PathFinder returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id)))
|
||||
peer = self.peers[node_id]
|
||||
coro = peer.pay(route, chan, amount_msat, payment_hash, addr.get_min_final_cltv_expiry())
|
||||
return addr, peer, coro
|
||||
peer = self.peers[route[0].node_id]
|
||||
return await peer.pay(route, chan, int(addr.amount * COIN * 1000), addr.paymenthash, addr.get_min_final_cltv_expiry())
|
||||
|
||||
def _create_route_from_invoice(self, decoded_invoice, amount_msat) -> List[RouteEdge]:
|
||||
@staticmethod
|
||||
def _check_invoice(invoice, amount_sat=None):
|
||||
addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP)
|
||||
if amount_sat:
|
||||
addr.amount = Decimal(amount_sat) / COIN
|
||||
if addr.amount is None:
|
||||
raise InvoiceError(_("Missing amount"))
|
||||
if addr.get_min_final_cltv_expiry() > 60 * 144:
|
||||
raise InvoiceError("{}\n{}".format(
|
||||
_("Invoice wants us to risk locking funds for unreasonably long."),
|
||||
f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
|
||||
return addr
|
||||
|
||||
def _create_route_from_invoice(self, decoded_invoice) -> List[RouteEdge]:
|
||||
amount_msat = int(decoded_invoice.amount * COIN * 1000)
|
||||
invoice_pubkey = decoded_invoice.pubkey.serialize()
|
||||
# use 'r' field from invoice
|
||||
route = None # type: List[RouteEdge]
|
||||
@@ -441,19 +450,8 @@ class LNWorker(PrintError):
|
||||
|
||||
async def force_close_channel(self, chan_id):
|
||||
chan = self.channels[chan_id]
|
||||
# local_commitment always gives back the next expected local_commitment,
|
||||
# but in this case, we want the current one. So substract one ctn number
|
||||
old_local_state = chan.config[LOCAL]
|
||||
chan.config[LOCAL]=chan.config[LOCAL]._replace(ctn=chan.config[LOCAL].ctn - 1)
|
||||
tx = chan.pending_local_commitment
|
||||
chan.config[LOCAL] = old_local_state
|
||||
tx.sign({bh2u(chan.config[LOCAL].multisig_key.pubkey): (chan.config[LOCAL].multisig_key.privkey, True)})
|
||||
remote_sig = chan.config[LOCAL].current_commitment_signature
|
||||
remote_sig = der_sig_from_sig_string(remote_sig) + b"\x01"
|
||||
none_idx = tx._inputs[0]["signatures"].index(None)
|
||||
tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig))
|
||||
assert tx.is_complete()
|
||||
return await self.network.broadcast_transaction(tx)
|
||||
peer = self.peers[chan.node_id]
|
||||
return await peer.force_close_channel(chan_id)
|
||||
|
||||
def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
|
||||
now = time.time()
|
||||
|
||||
@@ -16,6 +16,7 @@ from electrum.util import bh2u
|
||||
from electrum.lnbase import Peer, decode_msg, gen_msg
|
||||
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
|
||||
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
|
||||
from electrum.lnutil import PaymentFailure
|
||||
from electrum.lnrouter import ChannelDB, LNPathFinder
|
||||
from electrum.lnworker import LNWorker
|
||||
|
||||
@@ -33,7 +34,7 @@ def noop_lock():
|
||||
yield
|
||||
|
||||
class MockNetwork:
|
||||
def __init__(self):
|
||||
def __init__(self, tx_queue):
|
||||
self.callbacks = defaultdict(list)
|
||||
self.lnwatcher = None
|
||||
user_config = {}
|
||||
@@ -43,6 +44,7 @@ class MockNetwork:
|
||||
self.channel_db = ChannelDB(self)
|
||||
self.interface = None
|
||||
self.path_finder = LNPathFinder(self.channel_db)
|
||||
self.tx_queue = tx_queue
|
||||
|
||||
@property
|
||||
def callback_lock(self):
|
||||
@@ -55,12 +57,16 @@ class MockNetwork:
|
||||
def get_local_height(self):
|
||||
return 0
|
||||
|
||||
async def broadcast_transaction(self, tx):
|
||||
if self.tx_queue:
|
||||
await self.tx_queue.put(tx)
|
||||
|
||||
class MockLNWorker:
|
||||
def __init__(self, remote_keypair, local_keypair, chan):
|
||||
def __init__(self, remote_keypair, local_keypair, chan, tx_queue):
|
||||
self.chan = chan
|
||||
self.remote_keypair = remote_keypair
|
||||
self.node_keypair = local_keypair
|
||||
self.network = MockNetwork()
|
||||
self.network = MockNetwork(tx_queue)
|
||||
self.channels = {self.chan.channel_id: self.chan}
|
||||
self.invoices = {}
|
||||
|
||||
@@ -76,10 +82,12 @@ class MockLNWorker:
|
||||
return self.channels
|
||||
|
||||
def save_channel(self, chan):
|
||||
pass
|
||||
print("Ignoring channel save")
|
||||
|
||||
get_invoice = LNWorker.get_invoice
|
||||
_create_route_from_invoice = LNWorker._create_route_from_invoice
|
||||
_check_invoice = staticmethod(LNWorker._check_invoice)
|
||||
_pay_to_route = LNWorker._pay_to_route
|
||||
|
||||
class MockTransport:
|
||||
def __init__(self):
|
||||
@@ -120,18 +128,19 @@ class TestPeer(unittest.TestCase):
|
||||
self.alice_channel, self.bob_channel = create_test_channels()
|
||||
|
||||
def test_require_data_loss_protect(self):
|
||||
mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel)
|
||||
mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel, tx_queue=None)
|
||||
mock_transport = NoFeaturesTransport()
|
||||
p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 33), request_initial_sync=False, transport=mock_transport)
|
||||
mock_lnworker.peer = p1
|
||||
with self.assertRaises(LightningPeerConnectionClosed):
|
||||
asyncio.get_event_loop().run_until_complete(asyncio.wait_for(p1._main_loop(), 1))
|
||||
run(asyncio.wait_for(p1._main_loop(), 1))
|
||||
|
||||
def test_payment(self):
|
||||
def prepare_peers(self):
|
||||
k1, k2 = keypair(), keypair()
|
||||
t1, t2 = transport_pair()
|
||||
w1 = MockLNWorker(k1, k2, self.alice_channel)
|
||||
w2 = MockLNWorker(k2, k1, self.bob_channel)
|
||||
q1, q2 = asyncio.Queue(), asyncio.Queue()
|
||||
w1 = MockLNWorker(k1, k2, self.alice_channel, tx_queue=q1)
|
||||
w2 = MockLNWorker(k2, k1, self.bob_channel, tx_queue=q2)
|
||||
p1 = Peer(w1, LNPeerAddr("bogus1", 1337, k1.pubkey),
|
||||
request_initial_sync=False, transport=t1)
|
||||
p2 = Peer(w2, LNPeerAddr("bogus2", 1337, k2.pubkey),
|
||||
@@ -145,6 +154,11 @@ class TestPeer(unittest.TestCase):
|
||||
# this populates the channel graph:
|
||||
p1.mark_open(self.alice_channel)
|
||||
p2.mark_open(self.bob_channel)
|
||||
return p1, p2, w1, w2, q1, q2
|
||||
|
||||
@staticmethod
|
||||
def prepare_invoice(w2 # receiver
|
||||
):
|
||||
amount_btc = 100000/Decimal(COIN)
|
||||
payment_preimage = os.urandom(32)
|
||||
RHASH = sha256(payment_preimage)
|
||||
@@ -156,13 +170,23 @@ class TestPeer(unittest.TestCase):
|
||||
])
|
||||
pay_req = lnencode(addr, w2.node_keypair.privkey)
|
||||
w2.invoices[bh2u(RHASH)] = (bh2u(payment_preimage), pay_req)
|
||||
l = asyncio.get_event_loop()
|
||||
async def pay():
|
||||
fut = asyncio.Future()
|
||||
def evt_set(event, _lnworker, msg):
|
||||
fut.set_result(msg)
|
||||
w2.network.register_callback(evt_set, ['ln_message'])
|
||||
return pay_req
|
||||
|
||||
@staticmethod
|
||||
def prepare_ln_message_future(w2 # receiver
|
||||
):
|
||||
fut = asyncio.Future()
|
||||
def evt_set(event, _lnworker, msg):
|
||||
fut.set_result(msg)
|
||||
w2.network.register_callback(evt_set, ['ln_message'])
|
||||
return fut
|
||||
|
||||
def test_payment(self):
|
||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers()
|
||||
pay_req = self.prepare_invoice(w2)
|
||||
fut = self.prepare_ln_message_future(w2)
|
||||
|
||||
async def pay():
|
||||
addr, peer, coro = LNWorker._pay(w1, pay_req)
|
||||
await coro
|
||||
print("HTLC ADDED")
|
||||
@@ -170,4 +194,28 @@ class TestPeer(unittest.TestCase):
|
||||
gath.cancel()
|
||||
gath = asyncio.gather(pay(), p1._main_loop(), p2._main_loop())
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
l.run_until_complete(gath)
|
||||
run(gath)
|
||||
|
||||
def test_channel_usage_after_closing(self):
|
||||
p1, p2, w1, w2, q1, q2 = self.prepare_peers()
|
||||
pay_req = self.prepare_invoice(w2)
|
||||
|
||||
addr = w1._check_invoice(pay_req)
|
||||
route = w1._create_route_from_invoice(decoded_invoice=addr)
|
||||
|
||||
run(p1.force_close_channel(self.alice_channel.channel_id))
|
||||
# check if a tx (commitment transaction) was broadcasted:
|
||||
assert q1.qsize() == 1
|
||||
|
||||
with self.assertRaises(PaymentFailure) as e:
|
||||
w1._create_route_from_invoice(decoded_invoice=addr)
|
||||
self.assertEqual(str(e.exception), 'No path found')
|
||||
|
||||
peer = w1.peers[route[0].node_id]
|
||||
# AssertionError is ok since we shouldn't use old routes, and the
|
||||
# route finding should fail when channel is closed
|
||||
with self.assertRaises(AssertionError):
|
||||
run(asyncio.gather(w1._pay_to_route(route, addr), p1._main_loop(), p2._main_loop()))
|
||||
|
||||
def run(coro):
|
||||
asyncio.get_event_loop().run_until_complete(coro)
|
||||
|
||||
@@ -29,6 +29,7 @@ from electrum import lnchan
|
||||
from electrum import lnutil
|
||||
from electrum import bip32 as bip32_utils
|
||||
from electrum.lnutil import SENT, LOCAL, REMOTE, RECEIVED
|
||||
from electrum.ecc import sig_string_from_der_sig
|
||||
|
||||
one_bitcoin_in_msat = bitcoin.COIN * 1000
|
||||
|
||||
@@ -81,7 +82,8 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate
|
||||
per_commitment_secret_seed=seed,
|
||||
funding_locked_received=True,
|
||||
was_announced=False,
|
||||
current_commitment_signature=None,
|
||||
# just a random signature
|
||||
current_commitment_signature=sig_string_from_der_sig(bytes.fromhex('3046022100c66e112e22b91b96b795a6dd5f4b004f3acccd9a2a31bf104840f256855b7aa3022100e711b868b62d87c7edd95a2370e496b9cb6a38aff13c9f64f9ff2f3b2a0052dd')),
|
||||
current_htlc_signatures=None,
|
||||
),
|
||||
"constraints":lnbase.ChannelConstraints(
|
||||
@@ -185,6 +187,14 @@ class TestChannel(unittest.TestCase):
|
||||
|
||||
self.htlc = self.bob_channel.log[lnutil.REMOTE].adds[0]
|
||||
|
||||
def test_concurrent_reversed_payment(self):
|
||||
self.htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x02')
|
||||
self.htlc_dict['amount_msat'] += 1000
|
||||
bob_idx = self.bob_channel.add_htlc(self.htlc_dict)
|
||||
alice_idx = self.alice_channel.receive_htlc(self.htlc_dict)
|
||||
self.alice_channel.receive_new_commitment(*self.bob_channel.sign_next_commitment())
|
||||
self.assertEquals(len(self.alice_channel.pending_remote_commitment.outputs()), 3)
|
||||
|
||||
def test_SimpleAddSettleWorkflow(self):
|
||||
alice_channel, bob_channel = self.alice_channel, self.bob_channel
|
||||
htlc = self.htlc
|
||||
|
||||
Reference in New Issue
Block a user