Merge pull request #10364 from f321x/test_dont_settle_htlcs_forwarding
lnpeer/lnworker: check dont_settle_htlcs when forwarding
This commit is contained in:
@@ -2306,8 +2306,6 @@ class Peer(Logger, EventListener):
|
||||
local_height = self.network.blockchain().height()
|
||||
payment_hash = htlc_set.get_payment_hash()
|
||||
assert payment_hash is not None, "Empty htlc set?"
|
||||
self.lnworker.dont_expire_htlcs.pop(payment_hash.hex(), None)
|
||||
self.lnworker.dont_settle_htlcs.pop(payment_hash.hex(), None) # already failed
|
||||
for mpp_htlc in list(htlc_set.htlcs):
|
||||
chan = self.get_channel_by_id(mpp_htlc.channel_id)
|
||||
htlc_id = mpp_htlc.htlc.htlc_id
|
||||
@@ -3230,7 +3228,10 @@ class Peer(Logger, EventListener):
|
||||
# this was a forwarding set and it failed
|
||||
self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.FAILED)
|
||||
return error_bytes or failure_message, None, None
|
||||
preimage = self.lnworker.get_preimage(mpp_set.get_payment_hash())
|
||||
payment_hash = mpp_set.get_payment_hash()
|
||||
if payment_hash.hex() in self.lnworker.dont_settle_htlcs:
|
||||
return None, None, None
|
||||
preimage = self.lnworker.get_preimage(payment_hash)
|
||||
return None, preimage, None
|
||||
|
||||
return None
|
||||
|
||||
@@ -33,7 +33,7 @@ from electrum.util import NetworkRetryManager, bfh, OldTaskGroup, EventListener,
|
||||
from electrum.lnpeer import Peer
|
||||
from electrum.lntransport import LNPeerAddr
|
||||
from electrum.crypto import privkey_to_pubkey
|
||||
from electrum.lnutil import Keypair, PaymentFailure, LnFeatures, HTLCOwner, PaymentFeeBudget
|
||||
from electrum.lnutil import Keypair, PaymentFailure, LnFeatures, HTLCOwner, PaymentFeeBudget, RECEIVED
|
||||
from electrum.lnchannel import ChannelState, PeerState, Channel
|
||||
from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent
|
||||
from electrum.channel_db import ChannelDB
|
||||
@@ -41,7 +41,7 @@ from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo, PaySession
|
||||
from electrum.lnmsg import encode_msg, decode_msg
|
||||
from electrum import lnmsg
|
||||
from electrum.logging import console_stderr_handler, Logger
|
||||
from electrum.lnworker import PaymentInfo, RECEIVED
|
||||
from electrum.lnworker import PaymentInfo
|
||||
from electrum.lnonion import OnionFailureCode, OnionRoutingFailure, OnionHopsDataSingle, OnionPacket
|
||||
from electrum.lnutil import LOCAL, REMOTE, UpdateAddHtlc, RecvMPPResolution
|
||||
from electrum.invoices import PR_PAID, PR_UNPAID, Invoice, LN_EXPIRY_NEVER
|
||||
@@ -2045,76 +2045,6 @@ class TestPeerDirect(TestPeer):
|
||||
with self.assertRaises(SuccessfulTest):
|
||||
await f()
|
||||
|
||||
async def test_dont_settle_htlcs(self):
|
||||
"""
|
||||
Test that htlcs registered in LNWallet.dont_settle_htlcs don't get fulfilled if the preimage is available.
|
||||
"""
|
||||
async def run_test(test_trampoline, test_failure):
|
||||
alice_channel, bob_channel = create_test_channels()
|
||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||
if test_trampoline:
|
||||
await self._activate_trampoline(w1)
|
||||
# declare bob as trampoline node
|
||||
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {
|
||||
'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey),
|
||||
}
|
||||
|
||||
preimage = os.urandom(32)
|
||||
lnaddr, pay_req = self.prepare_invoice(
|
||||
w2,
|
||||
payment_preimage=preimage,
|
||||
# use a higher min final cltv delta so we can mine some blocks later
|
||||
min_final_cltv_delta=244,
|
||||
)
|
||||
|
||||
# add payment_hash to dont_settle_htlcs so the htlcs are not getting settled
|
||||
w2.dont_settle_htlcs[pay_req.rhash] = None
|
||||
|
||||
async def pay(lnaddr, pay_req):
|
||||
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
|
||||
result, log = await util.wait_for2(w1.pay_invoice(pay_req), timeout=3)
|
||||
if result is True:
|
||||
self.assertNotIn(pay_req.rhash, w2.dont_settle_htlcs)
|
||||
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
|
||||
return PaymentDone()
|
||||
else:
|
||||
self.assertIsNone(w2.get_preimage(lnaddr.paymenthash))
|
||||
return PaymentFailure()
|
||||
|
||||
async def wait_for_htlcs():
|
||||
payment_key = w2._get_payment_key(lnaddr.paymenthash)
|
||||
while payment_key.hex() not in w2.received_mpp_htlcs:
|
||||
await asyncio.sleep(0.05)
|
||||
w2.network.blockchain()._height += 25 # mine some blocks, shouldn't affect anything
|
||||
if test_failure:
|
||||
# delete preimage, this will fail htlcs even if registered in dont_settle_htlcs
|
||||
del w2._preimages[pay_req.rhash]
|
||||
return # pay() should fail now
|
||||
await asyncio.sleep(0.25) # give w2 some time to do mistakes
|
||||
self.assertEqual(w2.received_mpp_htlcs[payment_key.hex()].resolution, RecvMPPResolution.COMPLETE)
|
||||
# remove the payment hash from dont_settle_htlcs so the htlcs can get fulfilled
|
||||
del w2.dont_settle_htlcs[pay_req.rhash]
|
||||
|
||||
async def f():
|
||||
async with OldTaskGroup() as group:
|
||||
await group.spawn(p1._message_loop())
|
||||
await group.spawn(p1.htlc_switch())
|
||||
await group.spawn(p2._message_loop())
|
||||
await group.spawn(p2.htlc_switch())
|
||||
await asyncio.sleep(0.01)
|
||||
invoice_features = lnaddr.get_features()
|
||||
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
|
||||
pay_task = await group.spawn(pay(lnaddr, pay_req))
|
||||
await util.wait_for2(wait_for_htlcs(), timeout=2)
|
||||
raise await pay_task
|
||||
|
||||
await f()
|
||||
|
||||
for test_trampoline in [False, True]:
|
||||
for test_failure in [False, True]:
|
||||
with self.assertRaises(PaymentFailure if test_failure else PaymentDone):
|
||||
await run_test(test_trampoline, test_failure)
|
||||
|
||||
async def test_dont_expire_htlcs(self):
|
||||
"""
|
||||
Test that htlcs registered in LNWallet.dont_expire_htlcs don't get expired before the
|
||||
@@ -2978,6 +2908,79 @@ class TestPeerForwarding(TestPeer):
|
||||
any('bob->carol' in msg and 'on_update_fail_malformed_htlc' in msg for msg in logs.output)
|
||||
)
|
||||
|
||||
async def test_dont_settle_htlcs_receiver_and_forwarder(self):
|
||||
"""
|
||||
Test that the receiver and forwarder doesn't settle htlcs once they get the preimage if the payment
|
||||
hash is in LNWallet.dont_settle_htlcs. E.g. the forwarder could be a just-in-time channel provider.
|
||||
Alice -> Bob -> Carol. Carol and Bob shouldn't release the preimage.
|
||||
"""
|
||||
async def run_test(test_trampoline):
|
||||
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['line_graph'])
|
||||
peers = graph.peers.values()
|
||||
|
||||
if test_trampoline:
|
||||
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {
|
||||
graph.workers['bob'].name: LNPeerAddr(host="127.0.0.1", port=9735, pubkey=graph.workers['bob'].node_keypair.pubkey),
|
||||
}
|
||||
await self._activate_trampoline(graph.workers['carol'])
|
||||
await self._activate_trampoline(graph.workers['alice'])
|
||||
|
||||
lnaddr, pay_req = self.prepare_invoice(graph.workers['carol'], include_routing_hints=True)
|
||||
# test both receiver (carol) and forwarder (bob)
|
||||
graph.workers['bob'].dont_settle_htlcs[lnaddr.paymenthash.hex()] = None
|
||||
graph.workers['carol'].dont_settle_htlcs[lnaddr.paymenthash.hex()] = None
|
||||
|
||||
payment_successful = asyncio.Event()
|
||||
async def pay():
|
||||
self.assertEqual(PR_UNPAID, graph.workers['carol'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
|
||||
result, log = await graph.workers['alice'].pay_invoice(pay_req)
|
||||
self.assertEqual(PR_PAID, graph.workers['carol'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
|
||||
self.assertTrue(result)
|
||||
payment_successful.set()
|
||||
|
||||
async def check_doesnt_settle():
|
||||
while not graph.workers['carol'].received_mpp_htlcs:
|
||||
await asyncio.sleep(0.1) # wait until carol received the htlcs
|
||||
|
||||
await asyncio.sleep(0.2) # give carol time to accidentally release the preimage
|
||||
self.assertEqual(PR_UNPAID, graph.workers['carol'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
|
||||
self.assertIsNone(graph.workers['bob'].get_preimage(lnaddr.paymenthash), "bob got preimage from carol")
|
||||
# now allow carol to release the preimage to bob
|
||||
del graph.workers['carol'].dont_settle_htlcs[lnaddr.paymenthash.hex()]
|
||||
|
||||
# wait for carol to release the preimage to bob
|
||||
while not graph.workers['bob'].get_preimage(lnaddr.paymenthash):
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# give bob some time to settle the htlcs to alice (this would complete the payment)
|
||||
await asyncio.sleep(0.2)
|
||||
self.assertIsNone(graph.workers['alice'].get_preimage(lnaddr.paymenthash), "alice got preimage from bob")
|
||||
self.assertFalse(payment_successful.is_set(), "bob released preimage")
|
||||
|
||||
# now allow bob to settle the htlcs
|
||||
del graph.workers['bob'].dont_settle_htlcs[lnaddr.paymenthash.hex()]
|
||||
await payment_successful.wait()
|
||||
raise PaymentDone()
|
||||
|
||||
async def f():
|
||||
async with OldTaskGroup() as group:
|
||||
for peer in peers:
|
||||
await group.spawn(peer._message_loop())
|
||||
await group.spawn(peer.htlc_switch())
|
||||
for peer in peers:
|
||||
await peer.initialized
|
||||
|
||||
await group.spawn(pay())
|
||||
await group.spawn(check_doesnt_settle())
|
||||
# stop the taskgroup if anything takes too long
|
||||
await group.spawn(asyncio.wait_for(asyncio.sleep(4), timeout=3))
|
||||
|
||||
await f()
|
||||
|
||||
for trampoline in (False, True):
|
||||
with self.assertRaises(PaymentDone):
|
||||
await run_test(trampoline)
|
||||
|
||||
|
||||
class TestPeerDirectAnchors(TestPeerDirect):
|
||||
TEST_ANCHOR_CHANNELS = True
|
||||
|
||||
Reference in New Issue
Block a user