tests: add tests for "recv mpp confusion" bug
see https://github.com/spesmilo/electrum/security/advisories/GHSA-8r85-vp7r-hjxf
This commit is contained in:
@@ -1194,6 +1194,158 @@ class TestPeer(ElectrumTestCase):
|
||||
with self.assertRaises(PaymentDone):
|
||||
await f()
|
||||
|
||||
async def test_payment_recv_mpp_confusion1(self):
|
||||
"""Regression test for https://github.com/spesmilo/electrum/security/advisories/GHSA-8r85-vp7r-hjxf"""
|
||||
# This test checks that the following attack does not work:
|
||||
# - Bob creates invoice1: 1 BTC, H1, S1
|
||||
# - Bob creates invoice2: 1 BTC, H2, S2; both given to attacker to pay
|
||||
# - Alice sends htlc1: 0.1 BTC, H1, S1 (total_msat=1 BTC)
|
||||
# - Alice sends htlc2: 0.9 BTC, H2, S1 (total_msat=1 BTC)
|
||||
# - Bob(victim) reveals preimage for H1 and fulfills htlc1 (fails other)
|
||||
alice_channel, bob_channel = create_test_channels()
|
||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||
async def pay():
|
||||
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash))
|
||||
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr2.paymenthash))
|
||||
|
||||
route = (await w1.create_routes_from_invoice(amount_msat=1000, decoded_invoice=lnaddr1))[0][0].route
|
||||
p1.pay(
|
||||
route=route,
|
||||
chan=alice_channel,
|
||||
amount_msat=1000,
|
||||
total_msat=lnaddr1.get_amount_msat(),
|
||||
payment_hash=lnaddr1.paymenthash,
|
||||
min_final_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(),
|
||||
payment_secret=lnaddr1.payment_secret,
|
||||
)
|
||||
p1.pay(
|
||||
route=route,
|
||||
chan=alice_channel,
|
||||
amount_msat=lnaddr1.get_amount_msat() - 1000,
|
||||
total_msat=lnaddr1.get_amount_msat(),
|
||||
payment_hash=lnaddr2.paymenthash,
|
||||
min_final_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(),
|
||||
payment_secret=lnaddr1.payment_secret,
|
||||
)
|
||||
|
||||
while nhtlc_success + nhtlc_failed < 2:
|
||||
await htlc_resolved.wait()
|
||||
self.assertEqual(0, nhtlc_success)
|
||||
self.assertEqual(2, nhtlc_failed)
|
||||
raise SuccessfulTest()
|
||||
|
||||
w2.features |= LnFeatures.BASIC_MPP_OPT
|
||||
lnaddr1, _pay_req = self.prepare_invoice(w2, amount_msat=100_000_000)
|
||||
lnaddr2, _pay_req = self.prepare_invoice(w2, amount_msat=100_000_000)
|
||||
self.assertTrue(lnaddr1.get_features().supports(LnFeatures.BASIC_MPP_OPT))
|
||||
self.assertTrue(lnaddr2.get_features().supports(LnFeatures.BASIC_MPP_OPT))
|
||||
|
||||
async def f():
|
||||
async with OldTaskGroup() as group:
|
||||
await group.spawn(p1._message_loop())
|
||||
await group.spawn(p1.htlc_switch())
|
||||
await group.spawn(p2._message_loop())
|
||||
await group.spawn(p2.htlc_switch())
|
||||
await asyncio.sleep(0.01)
|
||||
await group.spawn(pay())
|
||||
|
||||
htlc_resolved = asyncio.Event()
|
||||
nhtlc_success = 0
|
||||
nhtlc_failed = 0
|
||||
async def on_htlc_fulfilled(*args):
|
||||
htlc_resolved.set()
|
||||
htlc_resolved.clear()
|
||||
nonlocal nhtlc_success
|
||||
nhtlc_success += 1
|
||||
async def on_htlc_failed(*args):
|
||||
htlc_resolved.set()
|
||||
htlc_resolved.clear()
|
||||
nonlocal nhtlc_failed
|
||||
nhtlc_failed += 1
|
||||
util.register_callback(on_htlc_fulfilled, ["htlc_fulfilled"])
|
||||
util.register_callback(on_htlc_failed, ["htlc_failed"])
|
||||
|
||||
try:
|
||||
with self.assertRaises(SuccessfulTest):
|
||||
await f()
|
||||
finally:
|
||||
util.unregister_callback(on_htlc_fulfilled)
|
||||
util.unregister_callback(on_htlc_failed)
|
||||
|
||||
async def test_payment_recv_mpp_confusion2(self):
|
||||
"""Regression test for https://github.com/spesmilo/electrum/security/advisories/GHSA-8r85-vp7r-hjxf"""
|
||||
# This test checks that the following attack does not work:
|
||||
# - Bob creates invoice: 1 BTC
|
||||
# - Alice sends htlc1: 0.1 BTC (total_msat=0.2 BTC)
|
||||
# - Alice sends htlc2: 0.1 BTC (total_msat=1 BTC)
|
||||
# - Bob(victim) reveals preimage and fulfills htlc2 (fails other)
|
||||
alice_channel, bob_channel = create_test_channels()
|
||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||
async def pay():
|
||||
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash))
|
||||
|
||||
route = (await w1.create_routes_from_invoice(amount_msat=1000, decoded_invoice=lnaddr1))[0][0].route
|
||||
p1.pay(
|
||||
route=route,
|
||||
chan=alice_channel,
|
||||
amount_msat=1000,
|
||||
total_msat=2000,
|
||||
payment_hash=lnaddr1.paymenthash,
|
||||
min_final_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(),
|
||||
payment_secret=lnaddr1.payment_secret,
|
||||
)
|
||||
p1.pay(
|
||||
route=route,
|
||||
chan=alice_channel,
|
||||
amount_msat=1000,
|
||||
total_msat=lnaddr1.get_amount_msat(),
|
||||
payment_hash=lnaddr1.paymenthash,
|
||||
min_final_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(),
|
||||
payment_secret=lnaddr1.payment_secret,
|
||||
)
|
||||
|
||||
while nhtlc_success + nhtlc_failed < 2:
|
||||
await htlc_resolved.wait()
|
||||
self.assertEqual(0, nhtlc_success)
|
||||
self.assertEqual(2, nhtlc_failed)
|
||||
raise SuccessfulTest()
|
||||
|
||||
w2.features |= LnFeatures.BASIC_MPP_OPT
|
||||
lnaddr1, _pay_req = self.prepare_invoice(w2, amount_msat=100_000_000)
|
||||
self.assertTrue(lnaddr1.get_features().supports(LnFeatures.BASIC_MPP_OPT))
|
||||
|
||||
async def f():
|
||||
async with OldTaskGroup() as group:
|
||||
await group.spawn(p1._message_loop())
|
||||
await group.spawn(p1.htlc_switch())
|
||||
await group.spawn(p2._message_loop())
|
||||
await group.spawn(p2.htlc_switch())
|
||||
await asyncio.sleep(0.01)
|
||||
await group.spawn(pay())
|
||||
|
||||
htlc_resolved = asyncio.Event()
|
||||
nhtlc_success = 0
|
||||
nhtlc_failed = 0
|
||||
async def on_htlc_fulfilled(*args):
|
||||
htlc_resolved.set()
|
||||
htlc_resolved.clear()
|
||||
nonlocal nhtlc_success
|
||||
nhtlc_success += 1
|
||||
async def on_htlc_failed(*args):
|
||||
htlc_resolved.set()
|
||||
htlc_resolved.clear()
|
||||
nonlocal nhtlc_failed
|
||||
nhtlc_failed += 1
|
||||
util.register_callback(on_htlc_fulfilled, ["htlc_fulfilled"])
|
||||
util.register_callback(on_htlc_failed, ["htlc_failed"])
|
||||
|
||||
try:
|
||||
with self.assertRaises(SuccessfulTest):
|
||||
await f()
|
||||
finally:
|
||||
util.unregister_callback(on_htlc_fulfilled)
|
||||
util.unregister_callback(on_htlc_failed)
|
||||
|
||||
async def _run_mpp(self, graph, fail_kwargs, success_kwargs):
|
||||
"""Tests a multipart payment scenario for failing and successful cases."""
|
||||
self.assertEqual(500_000_000_000, graph.channels[('alice', 'bob')].balance(LOCAL))
|
||||
|
||||
Reference in New Issue
Block a user