1
0

lnworker: differentiate PaymentInfo by direction

Allows storing two different payment info of the same payment hash by
including the direction into the db key.
We create and store PaymentInfo for sending attempts and for requests (receiving),
if we try to pay ourself (e.g. through a channel rebalance) the checks
in `save_payment_info` would prevent this and throw an exception.
By storing the PaymentInfos of outgoing and incoming payments separately in
the db this collision is avoided and it makes it easier to reason about
which PaymentInfo belongs where.
This commit is contained in:
f321x
2025-11-28 16:22:22 +01:00
parent 828fc569c9
commit 923d48f9db
12 changed files with 125 additions and 88 deletions

View File

@@ -11,6 +11,7 @@ import shutil
import electrum
from electrum.commands import Commands, eval_bool
from electrum import storage, wallet
from electrum.lnutil import RECEIVED
from electrum.lnworker import RecvMPPResolution
from electrum.wallet import Abstract_Wallet
from electrum.address_synchronizer import TX_HEIGHT_UNCONFIRMED
@@ -509,7 +510,7 @@ class TestCommandsTestnet(ElectrumTestCase):
)
invoice = lndecode(invoice=result['invoice'])
assert invoice.paymenthash.hex() == payment_hash
assert payment_hash in wallet.lnworker.payment_info
assert wallet.lnworker.get_payment_info(bytes.fromhex(payment_hash), direction=RECEIVED)
assert payment_hash in wallet.lnworker.dont_expire_htlcs
assert invoice.get_amount_sat() == 10000
assert invoice.get_description() == "test"
@@ -520,7 +521,7 @@ class TestCommandsTestnet(ElectrumTestCase):
payment_hash=payment_hash,
wallet=wallet,
)
assert payment_hash not in wallet.lnworker.payment_info
assert not wallet.lnworker.get_payment_info(bytes.fromhex(payment_hash), direction=RECEIVED)
assert payment_hash not in wallet.lnworker.dont_expire_htlcs
assert wallet.get_label_for_rhash(rhash=invoice.paymenthash.hex()) == ""
assert cancel_result['cancelled'] == payment_hash

View File

@@ -865,10 +865,10 @@ class TestPeerDirect(TestPeer):
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
results = {}
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await w1.pay_invoice(pay_req)
if result is True:
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
results[lnaddr] = PaymentDone()
else:
results[lnaddr] = PaymentFailure()
@@ -988,7 +988,7 @@ class TestPeerDirect(TestPeer):
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
async def try_pay_with_too_low_final_cltv_delta(lnaddr, w1=w1, w2=w2):
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
assert lnaddr.get_min_final_cltv_delta() == 400 # what the receiver expects
lnaddr.tags = [tag for tag in lnaddr.tags if tag[0] != 'c'] + [['c', 144]]
b11 = lnencode(lnaddr, w2.node_keypair.privkey)
@@ -1079,7 +1079,7 @@ class TestPeerDirect(TestPeer):
result, log = await w1.pay_invoice(pay_req)
assert result is True
# now pay the same invoice again, the payment should be rejected by w2
w1.set_payment_status(pay_req._lnaddr.paymenthash, PR_UNPAID)
w1.set_payment_status(pay_req._lnaddr.paymenthash, PR_UNPAID, direction=lnutil.SENT)
result, log = await w1.pay_invoice(pay_req)
if not result:
# w1.pay_invoice returned a payment failure as the payment got rejected by w2
@@ -1224,8 +1224,8 @@ class TestPeerDirect(TestPeer):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
async def pay():
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash))
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr2.paymenthash))
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash, direction=RECEIVED))
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr2.paymenthash, direction=RECEIVED))
route = (await w1.create_routes_from_invoice(amount_msat=1000, decoded_invoice=lnaddr1))[0][0].route
p1.pay(
@@ -1297,7 +1297,7 @@ class TestPeerDirect(TestPeer):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
async def pay():
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash))
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash, direction=RECEIVED))
route = (await w1.create_routes_from_invoice(amount_msat=1000, decoded_invoice=lnaddr1))[0][0].route
p1.pay(
@@ -1997,11 +1997,11 @@ class TestPeerDirect(TestPeer):
w2.dont_settle_htlcs[pay_req.rhash] = None
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await util.wait_for2(w1.pay_invoice(pay_req), timeout=3)
if result is True:
self.assertNotIn(pay_req.rhash, w2.dont_settle_htlcs)
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
return PaymentDone()
else:
self.assertIsNone(w2.get_preimage(lnaddr.paymenthash))
@@ -2067,10 +2067,10 @@ class TestPeerDirect(TestPeer):
w2.dont_expire_htlcs[pay_req.rhash] = None if not test_expiry else 20
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await util.wait_for2(w1.pay_invoice(pay_req), timeout=3)
if result is True:
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
return PaymentDone()
else:
self.assertIsNone(w2.get_preimage(lnaddr.paymenthash))
@@ -2210,12 +2210,12 @@ class TestPeerForwarding(TestPeer):
return split_amount_normal(total_amount, num_parts)
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
with mock.patch('electrum.mpp_split.split_amount_normal',
side_effect=mocked_split_amount_normal):
result, log = await graph.workers['bob'].pay_invoice(pay_req)
self.assertTrue(result)
self.assertEqual(PR_PAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_PAID, graph.workers['alice'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
async def f():
async with OldTaskGroup() as group:
@@ -2242,10 +2242,10 @@ class TestPeerForwarding(TestPeer):
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['square_graph'])
peers = graph.peers.values()
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await graph.workers['alice'].pay_invoice(pay_req)
self.assertTrue(result)
self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
raise PaymentDone()
async def f():
async with OldTaskGroup() as group:
@@ -2309,10 +2309,10 @@ class TestPeerForwarding(TestPeer):
graph.workers['carol'].network.config.TEST_FAIL_HTLCS_WITH_TEMP_NODE_FAILURE = True
peers = graph.peers.values()
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await graph.workers['alice'].pay_invoice(pay_req)
self.assertFalse(result)
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_msg.code)
raise PaymentDone()
async def f():
@@ -2336,11 +2336,11 @@ class TestPeerForwarding(TestPeer):
async def pay(lnaddr, pay_req):
self.assertEqual(500000000000, graph.channels[('alice', 'bob')].balance(LOCAL))
self.assertEqual(500000000000, graph.channels[('dave', 'bob')].balance(LOCAL))
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=2)
self.assertEqual(2, len(log))
self.assertTrue(result)
self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
self.assertEqual([graph.channels[('alice', 'carol')].short_channel_id, graph.channels[('carol', 'dave')].short_channel_id],
[edge.short_channel_id for edge in log[0].route])
self.assertEqual([graph.channels[('alice', 'bob')].short_channel_id, graph.channels[('bob', 'dave')].short_channel_id],
@@ -2436,11 +2436,11 @@ class TestPeerForwarding(TestPeer):
amount_to_pay = 100_000_000
peers = graph.peers.values()
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=3)
self.assertTrue(result)
self.assertEqual(2, len(log))
self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
self.assertEqual(OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, log[0].failure_msg.code)
liquidity_hints = graph.workers['alice'].network.path_finder.liquidity_hints
@@ -2507,14 +2507,14 @@ class TestPeerForwarding(TestPeer):
assert alice_w.network.channel_db is not None
lnaddr, pay_req = self.prepare_invoice(dave_w, include_routing_hints=True, amount_msat=amount_to_pay)
self.prepare_recipient(dave_w, lnaddr.paymenthash, test_hold_invoice, test_failure)
self.assertEqual(PR_UNPAID, dave_w.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, dave_w.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await alice_w.pay_invoice(pay_req, attempts=attempts)
if not bob_forwarding:
# reset to previous state, sleep 2s so that the second htlc can time out
graph.workers['bob'].enable_htlc_forwarding = True
await asyncio.sleep(2)
if result:
self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
# check mpp is cleaned up
async with OldTaskGroup() as g:
for peer in peers:
@@ -2642,7 +2642,7 @@ class TestPeerForwarding(TestPeer):
dest_w = graph.workers[destination_name]
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, dest_w.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, dest_w.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await sender_w.pay_invoice(pay_req, attempts=attempts)
async with OldTaskGroup() as g:
for peer in peers:
@@ -2653,7 +2653,7 @@ class TestPeerForwarding(TestPeer):
for peer in peers:
self.assertEqual(len(peer.lnworker.active_forwardings), 0)
if result:
self.assertEqual(PR_PAID, dest_w.get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_PAID, dest_w.get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
raise PaymentDone()
else:
raise NoPathFound()
@@ -2875,7 +2875,7 @@ class TestPeerForwarding(TestPeer):
peers = graph.peers.values()
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, graph.workers['carol'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_UNPAID, graph.workers['carol'].get_payment_status(lnaddr.paymenthash, direction=RECEIVED))
result, log = await graph.workers['alice'].pay_invoice(pay_req)
self.assertEqual(OnionFailureCode.INVALID_ONION_VERSION, log[0].failure_msg.code)
self.assertFalse(result, msg=log)