1
0

Merge pull request #10364 from f321x/test_dont_settle_htlcs_forwarding

lnpeer/lnworker: check dont_settle_htlcs when forwarding
This commit is contained in:
ghost43
2025-12-30 16:13:16 +00:00
committed by GitHub
2 changed files with 79 additions and 75 deletions

View File

@@ -2306,8 +2306,6 @@ class Peer(Logger, EventListener):
local_height = self.network.blockchain().height()
payment_hash = htlc_set.get_payment_hash()
assert payment_hash is not None, "Empty htlc set?"
self.lnworker.dont_expire_htlcs.pop(payment_hash.hex(), None)
self.lnworker.dont_settle_htlcs.pop(payment_hash.hex(), None) # already failed
for mpp_htlc in list(htlc_set.htlcs):
chan = self.get_channel_by_id(mpp_htlc.channel_id)
htlc_id = mpp_htlc.htlc.htlc_id
@@ -3230,7 +3228,10 @@ class Peer(Logger, EventListener):
# this was a forwarding set and it failed
self.lnworker.set_mpp_resolution(payment_key, RecvMPPResolution.FAILED)
return error_bytes or failure_message, None, None
preimage = self.lnworker.get_preimage(mpp_set.get_payment_hash())
payment_hash = mpp_set.get_payment_hash()
if payment_hash.hex() in self.lnworker.dont_settle_htlcs:
return None, None, None
preimage = self.lnworker.get_preimage(payment_hash)
return None, preimage, None
return None

View File

@@ -33,7 +33,7 @@ from electrum.util import NetworkRetryManager, bfh, OldTaskGroup, EventListener,
from electrum.lnpeer import Peer
from electrum.lntransport import LNPeerAddr
from electrum.crypto import privkey_to_pubkey
from electrum.lnutil import Keypair, PaymentFailure, LnFeatures, HTLCOwner, PaymentFeeBudget
from electrum.lnutil import Keypair, PaymentFailure, LnFeatures, HTLCOwner, PaymentFeeBudget, RECEIVED
from electrum.lnchannel import ChannelState, PeerState, Channel
from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent
from electrum.channel_db import ChannelDB
@@ -41,7 +41,7 @@ from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo, PaySession
from electrum.lnmsg import encode_msg, decode_msg
from electrum import lnmsg
from electrum.logging import console_stderr_handler, Logger
from electrum.lnworker import PaymentInfo, RECEIVED
from electrum.lnworker import PaymentInfo
from electrum.lnonion import OnionFailureCode, OnionRoutingFailure, OnionHopsDataSingle, OnionPacket
from electrum.lnutil import LOCAL, REMOTE, UpdateAddHtlc, RecvMPPResolution
from electrum.invoices import PR_PAID, PR_UNPAID, Invoice, LN_EXPIRY_NEVER
@@ -2045,76 +2045,6 @@ class TestPeerDirect(TestPeer):
with self.assertRaises(SuccessfulTest):
await f()
async def test_dont_settle_htlcs(self):
"""
Test that htlcs registered in LNWallet.dont_settle_htlcs don't get fulfilled if the preimage is available.
"""
async def run_test(test_trampoline, test_failure):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
if test_trampoline:
await self._activate_trampoline(w1)
# declare bob as trampoline node
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {
'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey),
}
preimage = os.urandom(32)
lnaddr, pay_req = self.prepare_invoice(
w2,
payment_preimage=preimage,
# use a higher min final cltv delta so we can mine some blocks later
min_final_cltv_delta=244,
)
# add payment_hash to dont_settle_htlcs so the htlcs are not getting settled
w2.dont_settle_htlcs[pay_req.rhash] = None
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await util.wait_for2(w1.pay_invoice(pay_req), timeout=3)
if result is True:
self.assertNotIn(pay_req.rhash, w2.dont_settle_htlcs)
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
return PaymentDone()
else:
self.assertIsNone(w2.get_preimage(lnaddr.paymenthash))
return PaymentFailure()
async def wait_for_htlcs():
payment_key = w2._get_payment_key(lnaddr.paymenthash)
while payment_key.hex() not in w2.received_mpp_htlcs:
await asyncio.sleep(0.05)
w2.network.blockchain()._height += 25 # mine some blocks, shouldn't affect anything
if test_failure:
# delete preimage, this will fail htlcs even if registered in dont_settle_htlcs
del w2._preimages[pay_req.rhash]
return # pay() should fail now
await asyncio.sleep(0.25) # give w2 some time to do mistakes
self.assertEqual(w2.received_mpp_htlcs[payment_key.hex()].resolution, RecvMPPResolution.COMPLETE)
# remove the payment hash from dont_settle_htlcs so the htlcs can get fulfilled
del w2.dont_settle_htlcs[pay_req.rhash]
async def f():
async with OldTaskGroup() as group:
await group.spawn(p1._message_loop())
await group.spawn(p1.htlc_switch())
await group.spawn(p2._message_loop())
await group.spawn(p2.htlc_switch())
await asyncio.sleep(0.01)
invoice_features = lnaddr.get_features()
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
pay_task = await group.spawn(pay(lnaddr, pay_req))
await util.wait_for2(wait_for_htlcs(), timeout=2)
raise await pay_task
await f()
for test_trampoline in [False, True]:
for test_failure in [False, True]:
with self.assertRaises(PaymentFailure if test_failure else PaymentDone):
await run_test(test_trampoline, test_failure)
async def test_dont_expire_htlcs(self):
"""
Test that htlcs registered in LNWallet.dont_expire_htlcs don't get expired before the
@@ -2978,6 +2908,79 @@ class TestPeerForwarding(TestPeer):
any('bob->carol' in msg and 'on_update_fail_malformed_htlc' in msg for msg in logs.output)
)
async def test_dont_settle_htlcs_receiver_and_forwarder(self):
"""
Test that the receiver and forwarder doesn't settle htlcs once they get the preimage if the payment
hash is in LNWallet.dont_settle_htlcs. E.g. the forwarder could be a just-in-time channel provider.
Alice -> Bob -> Carol. Carol and Bob shouldn't release the preimage.
"""
async def run_test(test_trampoline):
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['line_graph'])
peers = graph.peers.values()
if test_trampoline:
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {
graph.workers['bob'].name: LNPeerAddr(host="127.0.0.1", port=9735, pubkey=graph.workers['bob'].node_keypair.pubkey),
}
await self._activate_trampoline(graph.workers['carol'])
await self._activate_trampoline(graph.workers['alice'])
lnaddr, pay_req = self.prepare_invoice(graph.workers['carol'], include_routing_hints=True)
# test both receiver (carol) and forwarder (bob)
graph.workers['bob'].dont_settle_htlcs[lnaddr.paymenthash.hex()] = None
graph.workers['carol'].dont_settle_htlcs[lnaddr.paymenthash.hex()] = None
payment_successful = asyncio.Event()
async def pay():
self.assertEqual(PR_UNPAID, graph.workers['carol'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await graph.workers['alice'].pay_invoice(pay_req)
self.assertEqual(PR_PAID, graph.workers['carol'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
self.assertTrue(result)
payment_successful.set()
async def check_doesnt_settle():
while not graph.workers['carol'].received_mpp_htlcs:
await asyncio.sleep(0.1) # wait until carol received the htlcs
await asyncio.sleep(0.2) # give carol time to accidentally release the preimage
self.assertEqual(PR_UNPAID, graph.workers['carol'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
self.assertIsNone(graph.workers['bob'].get_preimage(lnaddr.paymenthash), "bob got preimage from carol")
# now allow carol to release the preimage to bob
del graph.workers['carol'].dont_settle_htlcs[lnaddr.paymenthash.hex()]
# wait for carol to release the preimage to bob
while not graph.workers['bob'].get_preimage(lnaddr.paymenthash):
await asyncio.sleep(0.1)
# give bob some time to settle the htlcs to alice (this would complete the payment)
await asyncio.sleep(0.2)
self.assertIsNone(graph.workers['alice'].get_preimage(lnaddr.paymenthash), "alice got preimage from bob")
self.assertFalse(payment_successful.is_set(), "bob released preimage")
# now allow bob to settle the htlcs
del graph.workers['bob'].dont_settle_htlcs[lnaddr.paymenthash.hex()]
await payment_successful.wait()
raise PaymentDone()
async def f():
async with OldTaskGroup() as group:
for peer in peers:
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
for peer in peers:
await peer.initialized
await group.spawn(pay())
await group.spawn(check_doesnt_settle())
# stop the taskgroup if anything takes too long
await group.spawn(asyncio.wait_for(asyncio.sleep(4), timeout=3))
await f()
for trampoline in (False, True):
with self.assertRaises(PaymentDone):
await run_test(trampoline)
class TestPeerDirectAnchors(TestPeerDirect):
TEST_ANCHOR_CHANNELS = True