1
0

tests: test_lnpeer: test_htlc_switch_iteration_benchmark

Benchmark how long a call to _run_htlc_switch_iteration takes with 10
trampoline mpp sets of 1 htlc each.
This commit is contained in:
f321x
2025-10-09 15:44:52 +02:00
parent f56b13b610
commit 042557da9b
2 changed files with 78 additions and 6 deletions

View File

@@ -14,6 +14,7 @@ from unittest import mock
from typing import Iterable, NamedTuple, Tuple, List, Dict, Sequence
from types import MappingProxyType
import time
import statistics
from aiorpcx import timeout_after, TaskTimeout
from electrum_ecc import ECPrivkey
@@ -1886,6 +1887,73 @@ class TestPeerDirect(TestPeer):
for _test_trampoline in [False, True]:
await run_test(_test_trampoline)
async def test_htlc_switch_iteration_benchmark(self):
"""Test how long a call to _run_htlc_switch_iteration takes with 10 trampoline
mpp sets of 1 htlc each. Raise if it takes longer than 20ms (median).
To create flamegraph with py-spy raise NUM_ITERATIONS to 1000 (for more samples) then run:
$ py-spy record -o flamegraph.svg --subprocesses -- python -m pytest tests/test_lnpeer.py::TestPeerDirect::test_htlc_switch_iteration_benchmark
"""
NUM_ITERATIONS = 25
alice_channel, bob_channel = create_test_channels(max_accepted_htlcs=20)
alice_p, bob_p, alice_w, bob_w, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
await self._activate_trampoline(alice_w)
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {
'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=bob_w.node_keypair.pubkey),
}
# create 10 invoices (10 pending htlc sets with 1 htlc each)
invoices = [] # type: list[tuple[LnAddr, Invoice]]
for i in range(10):
lnaddr, pay_req = self.prepare_invoice(bob_w)
# prevent bob from settling so that htlc switch will have to iterate through all pending htlcs
bob_w.dont_settle_htlcs[pay_req.rhash] = None
invoices.append((lnaddr, pay_req))
self.assertEqual(len(invoices), 10, msg=len(invoices))
iterations = []
do_benchmark = False
_run_bob_htlc_switch_iteration = bob_p._run_htlc_switch_iteration
def timed_htlc_switch_iteration():
start = time.perf_counter()
_run_bob_htlc_switch_iteration()
duration = time.perf_counter() - start
if do_benchmark:
iterations.append(duration)
bob_p._run_htlc_switch_iteration = timed_htlc_switch_iteration
async def benchmark_htlc_switch_iterations():
waited = 0
while not len(bob_w.received_mpp_htlcs) == 10 :
waited += 0.1
await asyncio.sleep(0.1)
if waited > 2:
raise TimeoutError()
nonlocal do_benchmark
do_benchmark = True
while len(iterations) < NUM_ITERATIONS:
await asyncio.sleep(0.05)
# average = sum(iterations) / len(iterations)
median_duration = statistics.median(iterations)
res = f"median duration per htlc switch iteration: {median_duration:.6f}s over {len(iterations)=}"
self.logger.info(res)
self.assertLess(median_duration, 0.02, msg=res)
raise SuccessfulTest()
async def f():
async with OldTaskGroup() as group:
await group.spawn(alice_p._message_loop())
await group.spawn(alice_p.htlc_switch())
await group.spawn(bob_p._message_loop())
await group.spawn(bob_p.htlc_switch())
await asyncio.sleep(0.01)
for _lnaddr, req in invoices:
await group.spawn(alice_w.pay_invoice(req))
await benchmark_htlc_switch_iterations()
with self.assertRaises(SuccessfulTest):
await f()
class TestPeerForwarding(TestPeer):